1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2024-10-18 20:42:51 +02:00

Compare commits

...

3 Commits

53 changed files with 1950 additions and 1398 deletions

View File

@ -11,6 +11,7 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Exceptions; using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite; using DHT.Server.Database.Sqlite;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging; using DHT.Utils.Logging;
namespace DHT.Desktop.Common; namespace DHT.Desktop.Common;
@ -20,9 +21,9 @@ static class DatabaseGui {
private const string DatabaseFileInitialName = "archive.dht"; private const string DatabaseFileInitialName = "archive.dht";
private static readonly IReadOnlyList<FilePickerFileType> DatabaseFileDialogFilter = new List<FilePickerFileType> { private static readonly IReadOnlyList<FilePickerFileType> DatabaseFileDialogFilter = [
FileDialogs.CreateFilter("Discord History Tracker Database", new [] { "dht" }) FileDialogs.CreateFilter("Discord History Tracker Database", ["dht"])
}; ];
public static async Task<string[]> NewOpenDatabaseFilesDialog(Window window, string? suggestedDirectory) { public static async Task<string[]> NewOpenDatabaseFilesDialog(Window window, string? suggestedDirectory) {
return await window.StorageProvider.OpenFiles(new FilePickerOpenOptions { return await window.StorageProvider.OpenFiles(new FilePickerOpenOptions {

View File

@ -4,5 +4,6 @@ namespace DHT.Desktop.Dialogs.Progress;
interface IProgressCallback { interface IProgressCallback {
Task Update(string message, int finishedItems, int totalItems); Task Update(string message, int finishedItems, int totalItems);
Task UpdateIndeterminate(string message);
Task Hide(); Task Hide();
} }

View File

@ -40,7 +40,7 @@
<TextBlock DockPanel.Dock="Right" Text="{Binding Items}" Classes="items" /> <TextBlock DockPanel.Dock="Right" Text="{Binding Items}" Classes="items" />
<TextBlock DockPanel.Dock="Left" Text="{Binding Message}" /> <TextBlock DockPanel.Dock="Left" Text="{Binding Message}" />
</DockPanel> </DockPanel>
<ProgressBar Value="{Binding Progress}" /> <ProgressBar IsIndeterminate="{Binding IsIndeterminate}" Value="{Binding Progress}" />
</StackPanel> </StackPanel>
</DataTemplate> </DataTemplate>
</ItemsRepeater.ItemTemplate> </ItemsRepeater.ItemTemplate>

View File

@ -7,6 +7,58 @@ namespace DHT.Desktop.Dialogs.Progress;
[SuppressMessage("ReSharper", "MemberCanBeInternal")] [SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class ProgressDialog : Window { public sealed partial class ProgressDialog : Window {
internal static async Task Show(Window owner, string title, Func<ProgressDialog, IProgressCallback, Task> action) {
var taskCompletionSource = new TaskCompletionSource();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
try {
await action(dialog, callbacks[0]);
taskCompletionSource.SetResult();
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
await taskCompletionSource.Task;
}
internal static async Task ShowIndeterminate(Window owner, string title, string message, Func<ProgressDialog, Task> action) {
var taskCompletionSource = new TaskCompletionSource();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
await callbacks[0].UpdateIndeterminate(message);
try {
await action(dialog);
taskCompletionSource.SetResult();
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
await taskCompletionSource.Task;
}
internal static async Task<T> ShowIndeterminate<T>(Window owner, string title, string message, Func<ProgressDialog, Task<T>> action) {
var taskCompletionSource = new TaskCompletionSource<T>();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
await callbacks[0].UpdateIndeterminate(message);
try {
taskCompletionSource.SetResult(await action(dialog));
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
return await taskCompletionSource.Task;
}
private bool isFinished = false; private bool isFinished = false;
private Task progressTask = Task.CompletedTask; private Task progressTask = Task.CompletedTask;

View File

@ -18,9 +18,10 @@ sealed class ProgressDialogModel : BaseModel {
[Obsolete("Designer")] [Obsolete("Designer")]
public ProgressDialogModel() {} public ProgressDialogModel() {}
public ProgressDialogModel(TaskRunner task, int progressItems = 1) { public ProgressDialogModel(string title, TaskRunner task, int progressItems = 1) {
this.Items = Enumerable.Range(0, progressItems).Select(static _ => new ProgressItem()).ToArray(); this.Title = title;
this.task = task; this.task = task;
this.Items = Enumerable.Range(0, progressItems).Select(static _ => new ProgressItem()).ToArray();
} }
internal async Task StartTask() { internal async Task StartTask() {
@ -43,6 +44,16 @@ sealed class ProgressDialogModel : BaseModel {
item.Message = message; item.Message = message;
item.Items = totalItems == 0 ? string.Empty : finishedItems.Format() + " / " + totalItems.Format(); item.Items = totalItems == 0 ? string.Empty : finishedItems.Format() + " / " + totalItems.Format();
item.Progress = totalItems == 0 ? 0 : 100 * finishedItems / totalItems; item.Progress = totalItems == 0 ? 0 : 100 * finishedItems / totalItems;
item.IsIndeterminate = false;
});
}
public async Task UpdateIndeterminate(string message) {
await Dispatcher.UIThread.InvokeAsync(() => {
item.Message = message;
item.Items = string.Empty;
item.Progress = 0;
item.IsIndeterminate = true;
}); });
} }

View File

@ -38,4 +38,11 @@ sealed class ProgressItem : BaseModel {
get => progress; get => progress;
set => Change(ref progress, value); set => Change(ref progress, value);
} }
private bool isIndeterminate;
public bool IsIndeterminate {
get => isIndeterminate;
set => Change(ref isIndeterminate, value);
}
} }

View File

@ -39,7 +39,8 @@ static class DiscordAppSettings {
public static async Task<bool?> AreDevToolsEnabled() { public static async Task<bool?> AreDevToolsEnabled() {
try { try {
return AreDevToolsEnabled(await ReadSettingsJson()); var settingsJson = await ReadSettingsJson().ConfigureAwait(false);
return AreDevToolsEnabled(settingsJson);
} catch (Exception e) { } catch (Exception e) {
Log.Error("Cannot read settings file."); Log.Error("Cannot read settings file.");
Log.Error(e); Log.Error(e);

View File

@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Threading.Tasks;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Server; using DHT.Server;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
@ -13,17 +14,17 @@ namespace DHT.Desktop.Main.Controls;
sealed class AttachmentFilterPanelModel : BaseModel, IDisposable { sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
public sealed record Unit(string Name, uint Scale); public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = { private static readonly Unit[] AllUnits = [
new ("B", 1), new Unit("B", 1),
new ("kB", 1024), new Unit("kB", 1024),
new ("MB", 1024 * 1024) new Unit("MB", 1024 * 1024)
}; ];
private static readonly HashSet<string> FilterProperties = new () { private static readonly HashSet<string> FilterProperties = [
nameof(LimitSize), nameof(LimitSize),
nameof(MaximumSize), nameof(MaximumSize),
nameof(MaximumSizeUnit) nameof(MaximumSizeUnit)
}; ];
public string FilterStatisticsText { get; private set; } = ""; public string FilterStatisticsText { get; private set; } = "";
@ -51,7 +52,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
private readonly State state; private readonly State state;
private readonly string verb; private readonly string verb;
private readonly AsyncValueComputer<long> matchingAttachmentCountComputer; private readonly RestartableTask<long> matchingAttachmentCountTask;
private long? matchingAttachmentCount; private long? matchingAttachmentCount;
private long? totalAttachmentCount; private long? totalAttachmentCount;
@ -62,7 +63,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
this.state = state; this.state = state;
this.verb = verb; this.verb = verb;
this.matchingAttachmentCountComputer = AsyncValueComputer<long>.WithResultProcessor(SetAttachmentCounts).Build(); this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext());
UpdateFilterStatistics(); UpdateFilterStatistics();
@ -90,14 +91,14 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
private void UpdateFilterStatistics() { private void UpdateFilterStatistics() {
var filter = CreateFilter(); var filter = CreateFilter();
if (filter.IsEmpty) { if (filter.IsEmpty) {
matchingAttachmentCountComputer.Cancel(); matchingAttachmentCountTask.Cancel();
matchingAttachmentCount = totalAttachmentCount; matchingAttachmentCount = totalAttachmentCount;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
} }
else { else {
matchingAttachmentCount = null; matchingAttachmentCount = null;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
matchingAttachmentCountComputer.Compute(() => state.Db.CountAttachments(filter)); matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Downloads.CountAttachments(filter, cancellationToken));
} }
} }
@ -115,7 +116,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
} }
public AttachmentFilter CreateFilter() { public AttachmentFilter CreateFilter() {
AttachmentFilter filter = new(); AttachmentFilter filter = new ();
if (LimitSize) { if (LimitSize) {
try { try {

View File

@ -8,6 +8,7 @@ using Avalonia.Controls;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.CheckBox; using DHT.Desktop.Dialogs.CheckBox;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server; using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
@ -18,7 +19,7 @@ using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls; namespace DHT.Desktop.Main.Controls;
sealed class MessageFilterPanelModel : BaseModel, IDisposable { sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private static readonly HashSet<string> FilterProperties = new () { private static readonly HashSet<string> FilterProperties = [
nameof(FilterByDate), nameof(FilterByDate),
nameof(StartDate), nameof(StartDate),
nameof(EndDate), nameof(EndDate),
@ -26,7 +27,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
nameof(IncludedChannels), nameof(IncludedChannels),
nameof(FilterByUser), nameof(FilterByUser),
nameof(IncludedUsers) nameof(IncludedUsers)
}; ];
public string FilterStatisticsText { get; private set; } = ""; public string FilterStatisticsText { get; private set; } = "";
@ -62,8 +63,8 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
set => Change(ref filterByChannel, value); set => Change(ref filterByChannel, value);
} }
public HashSet<ulong> IncludedChannels { public HashSet<ulong>? IncludedChannels {
get => includedChannels ?? state.Db.GetAllChannels().Select(static channel => channel.Id).ToHashSet(); get => includedChannels;
set => Change(ref includedChannels, value); set => Change(ref includedChannels, value);
} }
@ -72,8 +73,8 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
set => Change(ref filterByUser, value); set => Change(ref filterByUser, value);
} }
public HashSet<ulong> IncludedUsers { public HashSet<ulong>? IncludedUsers {
get => includedUsers ?? state.Db.GetAllUsers().Select(static user => user.Id).ToHashSet(); get => includedUsers;
set => Change(ref includedUsers, value); set => Change(ref includedUsers, value);
} }
@ -95,7 +96,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private readonly State state; private readonly State state;
private readonly string verb; private readonly string verb;
private readonly AsyncValueComputer<long> exportedMessageCountComputer; private readonly RestartableTask<long> exportedMessageCountTask;
private long? exportedMessageCount; private long? exportedMessageCount;
private long? totalMessageCount; private long? totalMessageCount;
@ -107,7 +108,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
this.state = state; this.state = state;
this.verb = verb; this.verb = verb;
this.exportedMessageCountComputer = AsyncValueComputer<long>.WithResultProcessor(SetExportedMessageCount).Build(); this.exportedMessageCountTask = new RestartableTask<long>(SetExportedMessageCount, TaskScheduler.FromCurrentSynchronizationContext());
UpdateFilterStatistics(); UpdateFilterStatistics();
UpdateChannelFilterLabel(); UpdateChannelFilterLabel();
@ -118,6 +119,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
} }
public void Dispose() { public void Dispose() {
exportedMessageCountTask.Cancel();
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged; state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
} }
@ -148,17 +150,29 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
} }
} }
private void UpdateChannelFilterLabel() {
long total = state.Db.Statistics.TotalChannels;
long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
private void UpdateUserFilterLabel() {
long total = state.Db.Statistics.TotalUsers;
long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
private void UpdateFilterStatistics() { private void UpdateFilterStatistics() {
var filter = CreateFilter(); var filter = CreateFilter();
if (filter.IsEmpty) { if (filter.IsEmpty) {
exportedMessageCountComputer.Cancel(); exportedMessageCountTask.Cancel();
exportedMessageCount = totalMessageCount; exportedMessageCount = totalMessageCount;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
} }
else { else {
exportedMessageCount = null; exportedMessageCount = null;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
exportedMessageCountComputer.Compute(() => state.Db.CountMessages(filter)); exportedMessageCountTask.Restart(cancellationToken => state.Db.Messages.Count(filter, cancellationToken));
} }
} }
@ -175,103 +189,98 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
OnPropertyChanged(nameof(FilterStatisticsText)); OnPropertyChanged(nameof(FilterStatisticsText));
} }
public async void OpenChannelFilterDialog() { public async Task OpenChannelFilterDialog() {
var servers = state.Db.GetAllServers().ToDictionary(static server => server.Id); async Task<List<CheckBoxItem<ulong>>> PrepareChannelItems(ProgressDialog dialog) {
var items = new List<CheckBoxItem<ulong>>(); var items = new List<CheckBoxItem<ulong>>();
var included = IncludedChannels; var servers = await state.Db.Servers.Get().ToDictionaryAsync(static server => server.Id);
foreach (var channel in state.Db.GetAllChannels()) { await foreach (var channel in state.Db.Channels.Get()) {
var channelId = channel.Id; var channelId = channel.Id;
var channelName = channel.Name; var channelName = channel.Name;
string title; string title;
if (servers.TryGetValue(channel.Server, out var server)) { if (servers.TryGetValue(channel.Server, out var server)) {
var titleBuilder = new StringBuilder(); var titleBuilder = new StringBuilder();
var serverType = server.Type; var serverType = server.Type;
titleBuilder.Append('[') titleBuilder.Append('[')
.Append(ServerTypes.ToString(serverType)) .Append(ServerTypes.ToString(serverType))
.Append("] "); .Append("] ");
if (serverType == ServerType.DirectMessage) { if (serverType == ServerType.DirectMessage) {
titleBuilder.Append(channelName); titleBuilder.Append(channelName);
}
else {
titleBuilder.Append(server.Name)
.Append(" - ")
.Append(channelName);
}
title = titleBuilder.ToString();
} }
else { else {
titleBuilder.Append(server.Name) title = channelName;
.Append(" - ")
.Append(channelName);
} }
title = titleBuilder.ToString(); items.Add(new CheckBoxItem<ulong>(channelId) {
} Title = title,
else { Checked = IncludedChannels == null || IncludedChannels.Contains(channelId)
title = channelName; });
} }
items.Add(new CheckBoxItem<ulong>(channelId) { return items;
Title = title, }
Checked = included.Contains(channelId)
}); const string Title = "Included Channels";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading channels...", PrepareChannelItems);
} catch (Exception e) {
await Dialog.ShowOk(window, Title, "Error loading channels: " + e.Message);
return;
} }
var result = await OpenIdFilterDialog(window, "Included Channels", items); var result = await OpenIdFilterDialog(Title, items);
if (result != null) { if (result != null) {
IncludedChannels = result; IncludedChannels = result;
} }
} }
public async void OpenUserFilterDialog() { public async Task OpenUserFilterDialog() {
var items = new List<CheckBoxItem<ulong>>(); async Task<List<CheckBoxItem<ulong>>> PrepareUserItems(ProgressDialog dialog) {
var included = IncludedUsers; var checkBoxItems = new List<CheckBoxItem<ulong>>();
foreach (var user in state.Db.GetAllUsers()) { await foreach (var user in state.Db.Users.Get()) {
var name = user.Name; var name = user.Name;
var discriminator = user.Discriminator; var discriminator = user.Discriminator;
items.Add(new CheckBoxItem<ulong>(user.Id) { checkBoxItems.Add(new CheckBoxItem<ulong>(user.Id) {
Title = discriminator == null ? name : name + " #" + discriminator, Title = discriminator == null ? name : name + " #" + discriminator,
Checked = included.Contains(user.Id) Checked = IncludedUsers == null || IncludedUsers.Contains(user.Id)
}); });
}
return checkBoxItems;
} }
var result = await OpenIdFilterDialog(window, "Included Users", items); const string Title = "Included Users";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading users...", PrepareUserItems);
} catch (Exception e) {
await Dialog.ShowOk(window, Title, "Error loading users: " + e.Message);
return;
}
var result = await OpenIdFilterDialog(Title, items);
if (result != null) { if (result != null) {
IncludedUsers = result; IncludedUsers = result;
} }
} }
private void UpdateChannelFilterLabel() { private async Task<HashSet<ulong>?> OpenIdFilterDialog(string title, List<CheckBoxItem<ulong>> items) {
long total = state.Db.Statistics.TotalChannels;
long included = FilterByChannel ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
private void UpdateUserFilterLabel() {
long total = state.Db.Statistics.TotalUsers;
long included = FilterByUser ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
public MessageFilter CreateFilter() {
MessageFilter filter = new();
if (FilterByDate) {
filter.StartDate = StartDate;
filter.EndDate = EndDate?.AddDays(1).AddMilliseconds(-1);
}
if (FilterByChannel) {
filter.ChannelIds = new HashSet<ulong>(IncludedChannels);
}
if (FilterByUser) {
filter.UserIds = new HashSet<ulong>(IncludedUsers);
}
return filter;
}
private static async Task<HashSet<ulong>?> OpenIdFilterDialog(Window window, string title, List<CheckBoxItem<ulong>> items) {
items.Sort(static (item1, item2) => item1.Title.CompareTo(item2.Title)); items.Sort(static (item1, item2) => item1.Title.CompareTo(item2.Title));
var model = new CheckBoxDialogModel<ulong>(items) { var model = new CheckBoxDialogModel<ulong>(items) {
@ -283,4 +292,23 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
return result == DialogResult.OkCancel.Ok ? model.SelectedItems.Select(static item => item.Item).ToHashSet() : null; return result == DialogResult.OkCancel.Ok ? model.SelectedItems.Select(static item => item.Item).ToHashSet() : null;
} }
public MessageFilter CreateFilter() {
MessageFilter filter = new ();
if (FilterByDate) {
filter.StartDate = StartDate;
filter.EndDate = EndDate?.AddDays(1).AddMilliseconds(-1);
}
if (FilterByChannel && IncludedChannels != null) {
filter.ChannelIds = new HashSet<ulong>(IncludedChannels);
}
if (FilterByUser && IncludedUsers != null) {
filter.UserIds = new HashSet<ulong>(IncludedUsers);
}
return filter;
}
} }

View File

@ -121,9 +121,10 @@ sealed class ServerConfigurationPanelModel : BaseModel, IDisposable {
ServerConfiguration.Port = port; ServerConfiguration.Port = port;
ServerConfiguration.Token = inputToken; ServerConfiguration.Token = inputToken;
await StartServer();
OnPropertyChanged(nameof(HasMadeChanges)); OnPropertyChanged(nameof(HasMadeChanges));
await StartServer();
} }
public void OnClickCancelChanges() { public void OnClickCancelChanges() {

View File

@ -34,7 +34,7 @@ sealed class StatusBarModel : BaseModel, IDisposable {
} }
public void Dispose() { public void Dispose() {
state.Server.StatusChanged += OnServerStatusChanged; state.Server.StatusChanged -= OnServerStatusChanged;
} }
private void OnServerStatusChanged(object? sender, ServerManager.Status e) { private void OnServerStatusChanged(object? sender, ServerManager.Status e) {

View File

@ -1,5 +1,4 @@
using System; using System;
using System.ComponentModel;
using System.IO; using System.IO;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -8,6 +7,7 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Main.Screens; using DHT.Desktop.Main.Screens;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Server; using DHT.Server;
using DHT.Server.Database;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models; using DHT.Utils.Models;
@ -15,7 +15,7 @@ namespace DHT.Desktop.Main;
sealed class MainWindowModel : BaseModel, IAsyncDisposable { sealed class MainWindowModel : BaseModel, IAsyncDisposable {
private const string DefaultTitle = "Discord History Tracker"; private const string DefaultTitle = "Discord History Tracker";
private static readonly Log Log = Log.ForType<MainWindowModel>(); private static readonly Log Log = Log.ForType<MainWindowModel>();
public string Title { get; private set; } = DefaultTitle; public string Title { get; private set; } = DefaultTitle;
@ -25,7 +25,6 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
private readonly WelcomeScreen welcomeScreen; private readonly WelcomeScreen welcomeScreen;
private readonly WelcomeScreenModel welcomeScreenModel; private readonly WelcomeScreenModel welcomeScreenModel;
private MainContentScreen? mainContentScreen;
private MainContentScreenModel? mainContentScreenModel; private MainContentScreenModel? mainContentScreenModel;
private readonly Window window; private readonly Window window;
@ -39,11 +38,11 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
this.window = window; this.window = window;
welcomeScreenModel = new WelcomeScreenModel(window); welcomeScreenModel = new WelcomeScreenModel(window);
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel }; welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel };
CurrentScreen = welcomeScreen; CurrentScreen = welcomeScreen;
welcomeScreenModel.PropertyChanged += WelcomeScreenModelOnPropertyChanged;
var dbFile = args.DatabaseFile; var dbFile = args.DatabaseFile;
if (!string.IsNullOrWhiteSpace(dbFile)) { if (!string.IsNullOrWhiteSpace(dbFile)) {
async void OnWindowOpened(object? o, EventArgs eventArgs) { async void OnWindowOpened(object? o, EventArgs eventArgs) {
@ -74,61 +73,59 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
} }
} }
private async void WelcomeScreenModelOnPropertyChanged(object? sender, PropertyChangedEventArgs e) { private async void OnDatabaseSelected(object? sender, IDatabaseFile db) {
if (e.PropertyName == nameof(welcomeScreenModel.Db)) { welcomeScreenModel.DatabaseSelected -= OnDatabaseSelected;
if (mainContentScreenModel != null) {
mainContentScreenModel.DatabaseClosed -= MainContentScreenModelOnDatabaseClosed; await DisposeState();
mainContentScreenModel.Dispose();
} state = new State(db);
if (state != null) { try {
await state.DisposeAsync(); await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token);
} } catch (Exception ex) {
Log.Error(ex);
if (welcomeScreenModel.Db == null) { await Dialog.ShowOk(window, "Internal Server Error", ex.Message);
state = null;
Title = DefaultTitle;
mainContentScreenModel = null;
mainContentScreen = null;
CurrentScreen = welcomeScreen;
}
else {
state = new State(welcomeScreenModel.Db);
try {
await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token);
} catch (Exception ex) {
Log.Error(ex);
await Dialog.ShowOk(window, "Internal Server Error", ex.Message);
}
Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
mainContentScreenModel = new MainContentScreenModel(window, state);
await mainContentScreenModel.Initialize();
mainContentScreenModel.DatabaseClosed += MainContentScreenModelOnDatabaseClosed;
mainContentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
CurrentScreen = mainContentScreen;
}
OnPropertyChanged(nameof(CurrentScreen));
OnPropertyChanged(nameof(Title));
window.Focus();
} }
mainContentScreenModel = new MainContentScreenModel(window, state);
mainContentScreenModel.DatabaseClosed += MainContentScreenModelOnDatabaseClosed;
Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
CurrentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
window.Focus();
} }
private void MainContentScreenModelOnDatabaseClosed(object? sender, EventArgs e) { private async void MainContentScreenModelOnDatabaseClosed(object? sender, EventArgs e) {
welcomeScreenModel.CloseDatabase(); if (mainContentScreenModel != null) {
mainContentScreenModel.DatabaseClosed -= MainContentScreenModelOnDatabaseClosed;
mainContentScreenModel.Dispose();
mainContentScreenModel = null;
}
await DisposeState();
Title = DefaultTitle;
CurrentScreen = welcomeScreen;
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
} }
public async ValueTask DisposeAsync() { private async Task DisposeState() {
mainContentScreenModel?.Dispose();
if (state != null) { if (state != null) {
await state.DisposeAsync(); await state.DisposeAsync();
state = null; state = null;
} }
}
welcomeScreenModel.Dispose();
public async ValueTask DisposeAsync() {
mainContentScreenModel?.Dispose();
await DisposeState();
} }
} }

View File

@ -1,6 +1,8 @@
using System; using System;
using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Server; using DHT.Server;
using DHT.Utils.Models; using DHT.Utils.Models;
@ -9,7 +11,7 @@ namespace DHT.Desktop.Main.Pages;
sealed class AdvancedPageModel : BaseModel, IDisposable { sealed class AdvancedPageModel : BaseModel, IDisposable {
public ServerConfigurationPanelModel ServerConfigurationModel { get; } public ServerConfigurationPanelModel ServerConfigurationModel { get; }
private readonly Window window; private readonly Window window;
private readonly State state; private readonly State state;
@ -27,8 +29,9 @@ sealed class AdvancedPageModel : BaseModel, IDisposable {
ServerConfigurationModel.Dispose(); ServerConfigurationModel.Dispose();
} }
public async void VacuumDatabase() { public async Task VacuumDatabase() {
state.Db.Vacuum(); const string Title = "Vacuum Database";
await Dialog.ShowOk(window, "Vacuum Database", "Done."); await ProgressDialog.ShowIndeterminate(window, Title, "Vacuuming database...", _ => state.Db.Vacuum());
await Dialog.ShowOk(window, Title, "Done.");
} }
} }

View File

@ -48,7 +48,7 @@
</DataGrid> </DataGrid>
</Expander> </Expander>
<StackPanel Orientation="Horizontal" Spacing="10"> <StackPanel Orientation="Horizontal" Spacing="10">
<Button Command="{Binding OnClickRetryFailedDownloads}" IsEnabled="{Binding HasFailedDownloads}">Retry Failed Downloads</Button> <Button Command="{Binding OnClickRetryFailedDownloads}" IsEnabled="{Binding IsRetryFailedOnDownloadsButtonEnabled}">Retry Failed Downloads</Button>
</StackPanel> </StackPanel>
</StackPanel> </StackPanel>
</StackPanel> </StackPanel>

View File

@ -11,12 +11,15 @@ using DHT.Server.Data;
using DHT.Server.Data.Aggregations; using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Models; using DHT.Utils.Models;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class AttachmentsPageModel : BaseModel, IDisposable { sealed class AttachmentsPageModel : BaseModel, IDisposable {
private static readonly Log Log = Log.ForType<AttachmentsPageModel>();
private static readonly DownloadItemFilter EnqueuedItemFilter = new () { private static readonly DownloadItemFilter EnqueuedItemFilter = new () {
IncludeStatuses = new HashSet<DownloadStatus> { IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued, DownloadStatus.Enqueued,
@ -24,15 +27,27 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
} }
}; };
private bool isThreadDownloadButtonEnabled = true; private bool isToggleDownloadButtonEnabled = true;
public bool IsToggleDownloadButtonEnabled {
get => isToggleDownloadButtonEnabled;
set => Change(ref isToggleDownloadButtonEnabled, value);
}
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading"; public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
public bool IsToggleDownloadButtonEnabled { private bool isRetryingFailedDownloads = false;
get => isThreadDownloadButtonEnabled;
set => Change(ref isThreadDownloadButtonEnabled, value); public bool IsRetryingFailedDownloads {
get => isRetryingFailedDownloads;
set {
isRetryingFailedDownloads = value;
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
}
} }
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && HasFailedDownloads;
public string DownloadMessage { get; set; } = ""; public string DownloadMessage { get; set; } = "";
public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value; public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value;
@ -43,26 +58,23 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
private readonly StatisticsRow statisticsFailed = new ("Failed"); private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped"); private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public List<StatisticsRow> StatisticsRows { public List<StatisticsRow> StatisticsRows => [
get { statisticsEnqueued,
return new List<StatisticsRow> { statisticsDownloaded,
statisticsEnqueued, statisticsFailed,
statisticsDownloaded, statisticsSkipped
statisticsFailed, ];
statisticsSkipped
};
}
}
public bool IsDownloading => state.Downloader.IsDownloading; public bool IsDownloading => state.Downloader.IsDownloading;
public bool HasFailedDownloads => statisticsFailed.Items > 0; public bool HasFailedDownloads => statisticsFailed.Items > 0;
private readonly State state; private readonly State state;
private readonly AsyncValueComputer<DownloadStatusStatistics>.Single downloadStatisticsComputer; private readonly ThrottledTask<int> enqueueDownloadItemsTask;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private IDisposable? finishedItemsSubscription; private IDisposable? finishedItemsSubscription;
private int doneItemsCount; private int doneItemsCount;
private int initialFinishedCount; private int totalEnqueuedItemCount;
private int? totalItemsToDownloadCount; private int? totalItemsToDownloadCount;
public AttachmentsPageModel() : this(State.Dummy) {} public AttachmentsPageModel() : this(State.Dummy) {}
@ -72,14 +84,17 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
FilterModel = new AttachmentFilterPanelModel(state); FilterModel = new AttachmentFilterPanelModel(state);
downloadStatisticsComputer = AsyncValueComputer<DownloadStatusStatistics>.WithResultProcessor(UpdateStatistics).WithOutdatedResults().BuildWithComputer(state.Db.GetDownloadStatusStatistics); enqueueDownloadItemsTask = new ThrottledTask<int>(OnItemsEnqueued, TaskScheduler.FromCurrentSynchronizationContext());
downloadStatisticsComputer.Recompute(); downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
RecomputeDownloadStatistics();
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged; state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
} }
public void Dispose() { public void Dispose() {
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged; state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
enqueueDownloadItemsTask.Dispose();
downloadStatisticsTask.Dispose();
finishedItemsSubscription?.Dispose(); finishedItemsSubscription?.Dispose();
FilterModel.Dispose(); FilterModel.Dispose();
} }
@ -87,23 +102,107 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) { private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) { if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
if (IsDownloading) { if (IsDownloading) {
EnqueueDownloadItems(); EnqueueDownloadItemsLater();
} }
else { else {
downloadStatisticsComputer.Recompute(); RecomputeDownloadStatistics();
} }
} }
else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) { else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) {
downloadStatisticsComputer.Recompute(); RecomputeDownloadStatistics();
} }
} }
private void EnqueueDownloadItems() { private async Task EnqueueDownloadItems() {
OnItemsEnqueued(await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter()));
}
private void EnqueueDownloadItemsLater() {
var filter = CreateAttachmentFilter();
enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken));
}
private void OnItemsEnqueued(int itemCount) {
totalEnqueuedItemCount += itemCount;
totalItemsToDownloadCount = totalEnqueuedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
private AttachmentFilter CreateAttachmentFilter() {
var filter = FilterModel.CreateFilter(); var filter = FilterModel.CreateFilter();
filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent; filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent;
state.Db.EnqueueDownloadItems(filter); return filter;
}
downloadStatisticsComputer.Recompute(); public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
RecomputeDownloadStatistics();
await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0;
totalEnqueuedItemCount = 0;
totalItemsToDownloadCount = null;
UpdateDownloadMessage();
}
else {
var finishedItems = await state.Downloader.Start();
finishedItemsSubscription = finishedItems.Select(static _ => true)
.Buffer(TimeSpan.FromMilliseconds(100))
.Select(static items => items.Count)
.Where(static items => items > 0)
.ObserveOn(AvaloniaScheduler.Instance)
.Subscribe(OnItemsFinished);
await EnqueueDownloadItems();
}
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading,
DownloadStatus.Success
}
};
await state.Db.Downloads.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) {
await EnqueueDownloadItems();
}
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
}
}
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics);
} }
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) { private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
@ -125,9 +224,9 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
if (hadFailedDownloads != HasFailedDownloads) { if (hadFailedDownloads != HasFailedDownloads) {
OnPropertyChanged(nameof(HasFailedDownloads)); OnPropertyChanged(nameof(HasFailedDownloads));
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
} }
totalItemsToDownloadCount = statisticsEnqueued.Items + statisticsDownloaded.Items + statisticsFailed.Items - initialFinishedCount;
UpdateDownloadMessage(); UpdateDownloadMessage();
} }
@ -138,65 +237,6 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
OnPropertyChanged(nameof(DownloadProgress)); OnPropertyChanged(nameof(DownloadProgress));
} }
private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
downloadStatisticsComputer.Recompute();
}
public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
downloadStatisticsComputer.Recompute();
state.Db.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0;
initialFinishedCount = 0;
totalItemsToDownloadCount = null;
UpdateDownloadMessage();
}
else {
var finishedItems = await state.Downloader.Start();
initialFinishedCount = statisticsDownloaded.Items + statisticsFailed.Items;
finishedItemsSubscription = finishedItems.Select(static _ => true)
.Buffer(TimeSpan.FromMilliseconds(100))
.Select(static items => items.Count)
.Where(static items => items > 0)
.ObserveOn(AvaloniaScheduler.Instance)
.Subscribe(OnItemsFinished);
EnqueueDownloadItems();
}
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
public void OnClickRetryFailedDownloads() {
var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading,
DownloadStatus.Success
}
};
state.Db.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) {
EnqueueDownloadItems();
}
}
public sealed class StatisticsRow { public sealed class StatisticsRow {
public string State { get; } public string State { get; }
public int Items { get; set; } public int Items { get; set; }

View File

@ -18,7 +18,7 @@ using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Import; using DHT.Server.Database.Import;
using DHT.Server.Database.Sqlite; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models; using DHT.Utils.Models;
@ -41,7 +41,7 @@ sealed class DatabasePageModel : BaseModel {
this.Db = state.Db; this.Db = state.Db;
} }
public async void OpenDatabaseFolder() { public async Task OpenDatabaseFolder() {
string file = Db.Path; string file = Db.Path;
string? folder = Path.GetDirectoryName(file); string? folder = Path.GetDirectoryName(file);
@ -72,18 +72,11 @@ sealed class DatabasePageModel : BaseModel {
DatabaseClosed?.Invoke(this, EventArgs.Empty); DatabaseClosed?.Invoke(this, EventArgs.Empty);
} }
public async void MergeWithDatabase() { public async Task MergeWithDatabase() {
var paths = await DatabaseGui.NewOpenDatabaseFilesDialog(window, Path.GetDirectoryName(Db.Path)); var paths = await DatabaseGui.NewOpenDatabaseFilesDialog(window, Path.GetDirectoryName(Db.Path));
if (paths.Length == 0) { if (paths.Length > 0) {
return; await ProgressDialog.Show(window, "Database Merge", async (dialog, callback) => await MergeWithDatabaseFromPaths(Db, paths, dialog, callback));
} }
ProgressDialog progressDialog = new ProgressDialog();
progressDialog.DataContext = new ProgressDialogModel(async callbacks => await MergeWithDatabaseFromPaths(Db, paths, progressDialog, callbacks[0])) {
Title = "Database Merge"
};
await progressDialog.ShowProgressDialog(window);
} }
private static async Task MergeWithDatabaseFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) { private static async Task MergeWithDatabaseFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
@ -97,7 +90,7 @@ sealed class DatabasePageModel : BaseModel {
} }
try { try {
target.AddFrom(db); await target.AddFrom(db);
return true; return true;
} finally { } finally {
db.Dispose(); db.Dispose();
@ -140,23 +133,16 @@ sealed class DatabasePageModel : BaseModel {
} }
} }
public async void ImportLegacyArchive() { public async Task ImportLegacyArchive() {
var paths = await window.StorageProvider.OpenFiles(new FilePickerOpenOptions { var paths = await window.StorageProvider.OpenFiles(new FilePickerOpenOptions {
Title = "Open Legacy DHT Archive", Title = "Open Legacy DHT Archive",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(Db.Path)), SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(Db.Path)),
AllowMultiple = true AllowMultiple = true
}); });
if (paths.Length == 0) { if (paths.Length > 0) {
return; await ProgressDialog.Show(window, "Legacy Archive Import", async (dialog, callback) => await ImportLegacyArchiveFromPaths(Db, paths, dialog, callback));
} }
ProgressDialog progressDialog = new ProgressDialog();
progressDialog.DataContext = new ProgressDialogModel(async callbacks => await ImportLegacyArchiveFromPaths(Db, paths, progressDialog, callbacks[0])) {
Title = "Legacy Archive Import"
};
await progressDialog.ShowProgressDialog(window);
} }
private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) { private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
@ -208,7 +194,7 @@ sealed class DatabasePageModel : BaseModel {
private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) { private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) {
int total = paths.Length; int total = paths.Length;
var oldStatistics = target.SnapshotStatistics(); var oldStatistics = await target.SnapshotStatistics();
int successful = 0; int successful = 0;
int finished = 0; int finished = 0;
@ -239,7 +225,8 @@ sealed class DatabasePageModel : BaseModel {
return; return;
} }
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, target.SnapshotStatistics(), successful, total, itemName)); var newStatistics = await target.SnapshotStatistics();
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, newStatistics, successful, total, itemName));
} }
private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) { private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) {

View File

@ -44,13 +44,7 @@ namespace DHT.Desktop.Main.Pages {
return; return;
} }
ProgressDialog progressDialog = new ProgressDialog { await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
DataContext = new ProgressDialogModel(async callbacks => await GenerateRandomData(channels, users, messages, callbacks[0])) {
Title = "Generating Random Data"
}
};
await progressDialog.ShowProgressDialog(window);
} }
private const int BatchSize = 500; private const int BatchSize = 500;
@ -83,12 +77,9 @@ namespace DHT.Desktop.Main.Pages {
Discriminator = rand.Next(0, 9999).ToString(), Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray(); }).ToArray();
state.Db.AddServer(server); await state.Db.Users.Add(users);
state.Db.AddUsers(users); await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
foreach (var channel in channels) {
state.Db.AddChannel(channel);
}
var now = DateTimeOffset.Now; var now = DateTimeOffset.Now;
int batchIndex = 0; int batchIndex = 0;
@ -111,13 +102,13 @@ namespace DHT.Desktop.Main.Pages {
Timestamp = timeMillis, Timestamp = timeMillis,
EditTimestamp = editMillis, EditTimestamp = editMillis,
RepliedToId = null, RepliedToId = null,
Attachments = ImmutableArray<Attachment>.Empty, Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableArray<Embed>.Empty, Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableArray<Reaction>.Empty, Reactions = ImmutableList<Reaction>.Empty,
}; };
}).ToArray(); }).ToArray();
state.Db.AddMessages(messages); await state.Db.Messages.Add(messages);
messageCount -= BatchSize; messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount); await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);

View File

@ -16,7 +16,7 @@
To start tracking messages, copy the tracking script and paste it into the console of either the Discord app, or your browser. The console is usually opened by pressing Ctrl+Shift+I. To start tracking messages, copy the tracking script and paste it into the console of either the Discord app, or your browser. The console is usually opened by pressing Ctrl+Shift+I.
</TextBlock> </TextBlock>
<StackPanel DockPanel.Dock="Left" Orientation="Horizontal" Spacing="10"> <StackPanel DockPanel.Dock="Left" Orientation="Horizontal" Spacing="10">
<Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick">Copy Tracking Script</Button> <Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick" IsEnabled="{Binding IsCopyTrackingScriptButtonEnabled}">Copy Tracking Script</Button>
</StackPanel> </StackPanel>
<TextBlock TextWrapping="Wrap" Margin="0 5 0 0"> <TextBlock TextWrapping="Wrap" Margin="0 5 0 0">
By default, the Discord app does not allow opening the console. The button below will change a hidden setting in the Discord app that controls whether the Ctrl+Shift+I shortcut is enabled. By default, the Discord app does not allow opening the console. The button below will change a hidden setting in the Discord app that controls whether the Ctrl+Shift+I shortcut is enabled.

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Web; using System.Web;
using Avalonia.Controls; using Avalonia.Controls;
@ -11,9 +12,16 @@ using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class TrackingPageModel : BaseModel { sealed class TrackingPageModel : BaseModel {
private bool areDevToolsEnabled; private bool isCopyTrackingScriptButtonEnabled = true;
private bool AreDevToolsEnabled { public bool IsCopyTrackingScriptButtonEnabled {
get => isCopyTrackingScriptButtonEnabled;
set => Change(ref isCopyTrackingScriptButtonEnabled, value);
}
private bool? areDevToolsEnabled;
private bool? AreDevToolsEnabled {
get => areDevToolsEnabled; get => areDevToolsEnabled;
set { set {
Change(ref areDevToolsEnabled, value); Change(ref areDevToolsEnabled, value);
@ -21,15 +29,19 @@ sealed class TrackingPageModel : BaseModel {
} }
} }
public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = true; public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = false;
public string ToggleAppDevToolsButtonText { public string ToggleAppDevToolsButtonText {
get { get {
if (!AreDevToolsEnabled.HasValue) {
return "Loading...";
}
if (!IsToggleAppDevToolsButtonEnabled) { if (!IsToggleAppDevToolsButtonEnabled) {
return "Unavailable"; return "Unavailable";
} }
return AreDevToolsEnabled ? "Disable Ctrl+Shift+I" : "Enable Ctrl+Shift+I"; return AreDevToolsEnabled.Value ? "Disable Ctrl+Shift+I" : "Enable Ctrl+Shift+I";
} }
} }
@ -40,20 +52,21 @@ sealed class TrackingPageModel : BaseModel {
public TrackingPageModel(Window window) { public TrackingPageModel(Window window) {
this.window = window; this.window = window;
Task.Factory.StartNew(InitializeDevToolsToggle, CancellationToken.None, TaskCreationOptions.None, TaskScheduler.FromCurrentSynchronizationContext());
} }
public async Task Initialize() { public async Task<bool> OnClickCopyTrackingScript() {
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled(); IsCopyTrackingScriptButtonEnabled = false;
if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value; try {
} return await CopyTrackingScript();
else { } finally {
IsToggleAppDevToolsButtonEnabled = false; IsCopyTrackingScriptButtonEnabled = true;
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
} }
} }
public async Task<bool> OnClickCopyTrackingScript() { private async Task<bool> CopyTrackingScript() {
string url = $"http://127.0.0.1:{ServerConfiguration.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}"; string url = $"http://127.0.0.1:{ServerConfiguration.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}";
string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url); string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url);
@ -72,10 +85,28 @@ sealed class TrackingPageModel : BaseModel {
} }
} }
public async void OnClickToggleAppDevTools() { private async Task InitializeDevToolsToggle() {
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled();
if (devToolsEnabled.HasValue) {
IsToggleAppDevToolsButtonEnabled = true;
AreDevToolsEnabled = devToolsEnabled.Value;
}
else {
IsToggleAppDevToolsButtonEnabled = false;
}
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
}
public async Task OnClickToggleAppDevTools() {
const string DialogTitle = "Discord App Settings File"; const string DialogTitle = "Discord App Settings File";
bool oldState = AreDevToolsEnabled; if (!AreDevToolsEnabled.HasValue) {
return;
}
bool oldState = AreDevToolsEnabled.Value;
bool newState = !oldState; bool newState = !oldState;
switch (await DiscordAppSettings.ConfigureDevTools(newState)) { switch (await DiscordAppSettings.ConfigureDevTools(newState)) {

View File

@ -11,6 +11,7 @@ using Avalonia.Platform.Storage;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.File; using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Server; using DHT.Server;
@ -23,7 +24,11 @@ using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class ViewerPageModel : BaseModel, IDisposable { sealed class ViewerPageModel : BaseModel, IDisposable {
public static readonly ConcurrentBag<string> TemporaryFiles = new (); public static readonly ConcurrentBag<string> TemporaryFiles = [];
private static readonly FilePickerFileType[] ViewerFileTypes = [
FileDialogs.CreateFilter("Discord History Viewer", ["html"])
];
public bool DatabaseToolFilterModeKeep { get; set; } = true; public bool DatabaseToolFilterModeKeep { get; set; } = true;
public bool DatabaseToolFilterModeRemove { get; set; } = false; public bool DatabaseToolFilterModeRemove { get; set; } = false;
@ -59,6 +64,58 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
HasFilters = FilterModel.HasAnyFilters; HasFilters = FilterModel.HasAnyFilters;
} }
public async void OnClickOpenViewer() {
try {
var fullPath = await PrepareTemporaryViewerFile();
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
await WriteViewerFile(fullPath, strategy);
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
} catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not save viewer: " + e.Message);
}
}
private async Task<string> PrepareTemporaryViewerFile() {
return await Task.Run(() => {
string rootPath = Path.Combine(Path.GetTempPath(), "DiscordHistoryTracker");
string filenameBase = Path.GetFileNameWithoutExtension(state.Db.Path) + "-" + DateTime.Now.ToString("yyyy-MM-dd");
string fullPath = Path.Combine(rootPath, filenameBase + ".html");
int counter = 0;
while (File.Exists(fullPath)) {
++counter;
fullPath = Path.Combine(rootPath, filenameBase + "-" + counter + ".html");
}
TemporaryFiles.Add(fullPath);
Directory.CreateDirectory(rootPath);
return fullPath;
});
}
public async void OnClickSaveViewer() {
string? path = await window.StorageProvider.SaveFile(new FilePickerSaveOptions {
Title = "Save Viewer",
FileTypeChoices = ViewerFileTypes,
SuggestedFileName = Path.GetFileNameWithoutExtension(state.Db.Path) + ".html",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(state.Db.Path)),
});
if (path == null) {
return;
}
try {
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
} catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not save viewer: " + e.Message);
}
}
private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) { private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) {
const string ArchiveTag = "/*[ARCHIVE]*/"; const string ArchiveTag = "/*[ARCHIVE]*/";
@ -96,54 +153,23 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
File.Delete(jsonTempFile); File.Delete(jsonTempFile);
} }
public async void OnClickOpenViewer() { public async Task OnClickApplyFiltersToDatabase() {
string rootPath = Path.Combine(Path.GetTempPath(), "DiscordHistoryTracker");
string filenameBase = Path.GetFileNameWithoutExtension(state.Db.Path) + "-" + DateTime.Now.ToString("yyyy-MM-dd");
string fullPath = Path.Combine(rootPath, filenameBase + ".html");
int counter = 0;
while (File.Exists(fullPath)) {
++counter;
fullPath = Path.Combine(rootPath, filenameBase + "-" + counter + ".html");
}
TemporaryFiles.Add(fullPath);
Directory.CreateDirectory(rootPath);
await WriteViewerFile(fullPath, new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token));
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
}
private static readonly FilePickerFileType[] ViewerFileTypes = {
FileDialogs.CreateFilter("Discord History Viewer", new string[] { "html" }),
};
public async void OnClickSaveViewer() {
string? path = await window.StorageProvider.SaveFile(new FilePickerSaveOptions {
Title = "Save Viewer",
FileTypeChoices = ViewerFileTypes,
SuggestedFileName = Path.GetFileNameWithoutExtension(state.Db.Path) + ".html",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(state.Db.Path)),
});
if (path != null) {
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
}
}
public async void OnClickApplyFiltersToDatabase() {
var filter = FilterModel.CreateFilter(); var filter = FilterModel.CreateFilter();
var messageCount = await ProgressDialog.ShowIndeterminate(window, "Apply Filters", "Counting matching messages...", _ => state.Db.Messages.Count(filter));
if (DatabaseToolFilterModeKeep) { if (DatabaseToolFilterModeKeep) {
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Keep Matching Messages in This Database", state.Db.CountMessages(filter).Pluralize("message") + " will be kept, and the rest will be removed from this database. This action cannot be undone. Proceed?")) { if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Keep Matching Messages in This Database", messageCount.Pluralize("message") + " will be kept, and the rest will be removed from this database. This action cannot be undone. Proceed?")) {
state.Db.RemoveMessages(filter, FilterRemovalMode.KeepMatching); await ApplyFilterToDatabase(filter, FilterRemovalMode.KeepMatching);
} }
} }
else if (DatabaseToolFilterModeRemove) { else if (DatabaseToolFilterModeRemove) {
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Remove Matching Messages in This Database", state.Db.CountMessages(filter).Pluralize("message") + " will be removed from this database. This action cannot be undone. Proceed?")) { if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Remove Matching Messages in This Database", messageCount.Pluralize("message") + " will be removed from this database. This action cannot be undone. Proceed?")) {
state.Db.RemoveMessages(filter, FilterRemovalMode.RemoveMatching); await ApplyFilterToDatabase(filter, FilterRemovalMode.RemoveMatching);
} }
} }
} }
private async Task ApplyFilterToDatabase(MessageFilter filter, FilterRemovalMode removalMode) {
await ProgressDialog.ShowIndeterminate(window, "Apply Filters", "Removing messages...", _ => state.Db.Messages.Remove(filter, removalMode));
}
} }

View File

@ -1,16 +1,12 @@
using System; using System;
using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Desktop.Main.Pages; using DHT.Desktop.Main.Pages;
using DHT.Server; using DHT.Server;
using DHT.Utils.Logging;
namespace DHT.Desktop.Main.Screens; namespace DHT.Desktop.Main.Screens;
sealed class MainContentScreenModel : IDisposable { sealed class MainContentScreenModel : IDisposable {
private static readonly Log Log = Log.ForType<MainContentScreenModel>();
public DatabasePage DatabasePage { get; } public DatabasePage DatabasePage { get; }
private DatabasePageModel DatabasePageModel { get; } private DatabasePageModel DatabasePageModel { get; }
@ -75,10 +71,6 @@ sealed class MainContentScreenModel : IDisposable {
StatusBarModel = new StatusBarModel(state); StatusBarModel = new StatusBarModel(state);
} }
public async Task Initialize() {
await TrackingPageModel.Initialize();
}
public void Dispose() { public void Dispose() {
AttachmentsPageModel.Dispose(); AttachmentsPageModel.Dispose();
ViewerPageModel.Dispose(); ViewerPageModel.Dispose();

View File

@ -32,7 +32,7 @@
<TextBlock Text="{Binding Version, StringFormat=Discord History Tracker v{0}}" FontSize="25" Margin="0 0 0 30" HorizontalAlignment="Center" /> <TextBlock Text="{Binding Version, StringFormat=Discord History Tracker v{0}}" FontSize="25" Margin="0 0 0 30" HorizontalAlignment="Center" />
<StackPanel Orientation="Horizontal" HorizontalAlignment="Center"> <StackPanel Orientation="Horizontal" HorizontalAlignment="Center">
<Button Command="{Binding OpenOrCreateDatabase}">Open or Create Database</Button> <Button Command="{Binding OpenOrCreateDatabase}" IsEnabled="{Binding IsOpenOrCreateDatabaseButtonEnabled}">Open or Create Database</Button>
<Button Command="{Binding ShowAboutDialog}">About</Button> <Button Command="{Binding ShowAboutDialog}">About</Button>
<Button Command="{Binding Exit}">Exit</Button> <Button Command="{Binding Exit}">Exit</Button>
</StackPanel> </StackPanel>

View File

@ -7,17 +7,23 @@ using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress; using DHT.Desktop.Dialogs.Progress;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Sqlite; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Models; using DHT.Utils.Models;
namespace DHT.Desktop.Main.Screens; namespace DHT.Desktop.Main.Screens;
sealed class WelcomeScreenModel : BaseModel, IDisposable { sealed class WelcomeScreenModel : BaseModel {
public string Version => Program.Version; public string Version => Program.Version;
public IDatabaseFile? Db { get; private set; } private bool isOpenOrCreateDatabaseButtonEnabled = true;
public bool HasDatabase => Db != null;
public bool IsOpenOrCreateDatabaseButtonEnabled {
get => isOpenOrCreateDatabaseButtonEnabled;
set => Change(ref isOpenOrCreateDatabaseButtonEnabled, value);
}
public event EventHandler<IDatabaseFile>? DatabaseSelected;
private readonly Window window; private readonly Window window;
private string? dbFilePath; private string? dbFilePath;
@ -29,23 +35,25 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
this.window = window; this.window = window;
} }
public async void OpenOrCreateDatabase() { public async Task OpenOrCreateDatabase() {
var path = await DatabaseGui.NewOpenOrCreateDatabaseFileDialog(window, Path.GetDirectoryName(dbFilePath)); IsOpenOrCreateDatabaseButtonEnabled = false;
if (path != null) { try {
await OpenOrCreateDatabaseFromPath(path); var path = await DatabaseGui.NewOpenOrCreateDatabaseFileDialog(window, Path.GetDirectoryName(dbFilePath));
if (path != null) {
await OpenOrCreateDatabaseFromPath(path);
}
} finally {
IsOpenOrCreateDatabaseButtonEnabled = true;
} }
} }
public async Task OpenOrCreateDatabaseFromPath(string path) { public async Task OpenOrCreateDatabaseFromPath(string path) {
if (Db != null) {
Db = null;
}
dbFilePath = path; dbFilePath = path;
Db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, window, new SchemaUpgradeCallbacks(window));
var db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, window, new SchemaUpgradeCallbacks(window));
OnPropertyChanged(nameof(Db)); if (db != null) {
OnPropertyChanged(nameof(HasDatabase)); DatabaseSelected?.Invoke(this, db);
}
} }
private sealed class SchemaUpgradeCallbacks : ISchemaUpgradeCallbacks { private sealed class SchemaUpgradeCallbacks : ISchemaUpgradeCallbacks {
@ -67,12 +75,8 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
await doUpgrade(reporter); await doUpgrade(reporter);
await Task.Delay(TimeSpan.FromMilliseconds(600)); await Task.Delay(TimeSpan.FromMilliseconds(600));
} }
await new ProgressDialog { await new ProgressDialog { DataContext = new ProgressDialogModel("Upgrading Database", StartUpgrade, progressItems: 3) }.ShowProgressDialog(window);
DataContext = new ProgressDialogModel(StartUpgrade, progressItems: 3) {
Title = "Upgrading Database"
}
}.ShowProgressDialog(window);
} }
private sealed class ProgressReporter : ISchemaUpgradeCallbacks.IProgressReporter { private sealed class ProgressReporter : ISchemaUpgradeCallbacks.IProgressReporter {
@ -109,22 +113,11 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
} }
} }
public void CloseDatabase() { public async Task ShowAboutDialog() {
Dispose(); await new AboutWindow { DataContext = new AboutWindowModel() }.ShowDialog(window);
OnPropertyChanged(nameof(Db));
OnPropertyChanged(nameof(HasDatabase));
}
public async void ShowAboutDialog() {
await new AboutWindow { DataContext = new AboutWindowModel() }.ShowDialog(this.window);
} }
public void Exit() { public void Exit() {
window.Close(); window.Close();
} }
public void Dispose() {
Db?.Dispose();
Db = null;
}
} }

View File

@ -10,7 +10,7 @@ public readonly struct Message {
public long Timestamp { get; init; } public long Timestamp { get; init; }
public long? EditTimestamp { get; init; } public long? EditTimestamp { get; init; }
public ulong? RepliedToId { get; init; } public ulong? RepliedToId { get; init; }
public ImmutableArray<Attachment> Attachments { get; init; } public ImmutableList<Attachment> Attachments { get; init; }
public ImmutableArray<Embed> Embeds { get; init; } public ImmutableList<Embed> Embeds { get; init; }
public ImmutableArray<Reaction> Reactions { get; init; } public ImmutableList<Reaction> Reactions { get; init; }
} }

View File

@ -1,29 +1,32 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
namespace DHT.Server.Database; namespace DHT.Server.Database;
public static class DatabaseExtensions { public static class DatabaseExtensions {
public static void AddFrom(this IDatabaseFile target, IDatabaseFile source) { public static async Task AddFrom(this IDatabaseFile target, IDatabaseFile source) {
target.AddServers(source.GetAllServers()); await target.Users.Add(await source.Users.Get().ToListAsync());
target.AddChannels(source.GetAllChannels()); await target.Servers.Add(await source.Servers.Get().ToListAsync());
target.AddUsers(source.GetAllUsers().ToArray()); await target.Channels.Add(await source.Channels.Get().ToListAsync());
target.AddMessages(source.GetMessages().ToArray());
foreach (var download in source.GetDownloadsWithoutData()) { const int MessageBatchSize = 100;
target.AddDownload(download.Status == DownloadStatus.Success ? source.GetDownloadWithData(download) : download); List<Message> batchedMessages = new (MessageBatchSize);
await foreach (var message in source.Messages.Get()) {
batchedMessages.Add(message);
if (batchedMessages.Count >= MessageBatchSize) {
await target.Messages.Add(batchedMessages);
batchedMessages.Clear();
}
} }
}
await target.Messages.Add(batchedMessages);
internal static void AddServers(this IDatabaseFile target, IEnumerable<Data.Server> servers) { await foreach (var download in source.Downloads.GetWithoutData()) {
foreach (var server in servers) { await target.Downloads.AddDownload(download.Status == DownloadStatus.Success ? await source.Downloads.HydrateWithData(download) : download);
target.AddServer(server);
}
}
internal static void AddChannels(this IDatabaseFile target, IEnumerable<Channel> channels) {
foreach (var channel in channels) {
target.AddChannel(channel);
} }
} }
} }

View File

@ -1,90 +1,31 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using DHT.Server.Data; using System.Threading.Tasks;
using DHT.Server.Data.Aggregations; using DHT.Server.Database.Repositories;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database; namespace DHT.Server.Database;
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public sealed class DummyDatabaseFile : IDatabaseFile { sealed class DummyDatabaseFile : IDatabaseFile {
public static DummyDatabaseFile Instance { get; } = new(); public static DummyDatabaseFile Instance { get; } = new ();
public string Path => ""; public string Path => "";
public DatabaseStatistics Statistics { get; } = new(); public DatabaseStatistics Statistics { get; } = new ();
public IUserRepository Users { get; } = new IUserRepository.Dummy();
public IServerRepository Servers { get; } = new IServerRepository.Dummy();
public IChannelRepository Channels { get; } = new IChannelRepository.Dummy();
public IMessageRepository Messages { get; } = new IMessageRepository.Dummy();
public IDownloadRepository Downloads { get; } = new IDownloadRepository.Dummy();
private DummyDatabaseFile() {} private DummyDatabaseFile() {}
public DatabaseStatisticsSnapshot SnapshotStatistics() { public Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return new(); return Task.FromResult(new DatabaseStatisticsSnapshot());
} }
public void AddServer(Data.Server server) {} public Task Vacuum() {
return Task.CompletedTask;
public List<Data.Server> GetAllServers() {
return new();
} }
public void AddChannel(Channel channel) {}
public List<Channel> GetAllChannels() {
return new();
}
public void AddUsers(User[] users) {}
public List<User> GetAllUsers() {
return new();
}
public void AddMessages(Message[] messages) {}
public int CountMessages(MessageFilter? filter = null) {
return 0;
}
public List<Message> GetMessages(MessageFilter? filter = null) {
return new();
}
public HashSet<ulong> GetMessageIds(MessageFilter? filter = null) {
return new();
}
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {}
public int CountAttachments(AttachmentFilter? filter = null) {
return new();
}
public List<Data.Download> GetDownloadsWithoutData() {
return new();
}
public Data.Download GetDownloadWithData(Data.Download download) {
return download;
}
public DownloadedAttachment? GetDownloadedAttachment(string url) {
return null;
}
public void AddDownload(Data.Download download) {}
public void EnqueueDownloadItems(AttachmentFilter? filter = null) {}
public List<DownloadItem> PullEnqueuedDownloadItems(int count) {
return new();
}
public void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {}
public DownloadStatusStatistics GetDownloadStatusStatistics() {
return new();
}
public void Vacuum() {}
public void Dispose() {} public void Dispose() {}
} }

View File

@ -21,7 +21,7 @@ public static class ViewerJsonExport {
var includedChannelIds = new HashSet<ulong>(); var includedChannelIds = new HashSet<ulong>();
var includedServerIds = new HashSet<ulong>(); var includedServerIds = new HashSet<ulong>();
var includedMessages = db.GetMessages(filter); var includedMessages = await db.Messages.Get(filter).ToListAsync();
var includedChannels = new List<Channel>(); var includedChannels = new List<Channel>();
foreach (var message in includedMessages) { foreach (var message in includedMessages) {
@ -29,23 +29,23 @@ public static class ViewerJsonExport {
includedChannelIds.Add(message.Channel); includedChannelIds.Add(message.Channel);
} }
foreach (var channel in db.GetAllChannels()) { await foreach (var channel in db.Channels.Get()) {
if (includedChannelIds.Contains(channel.Id)) { if (includedChannelIds.Contains(channel.Id)) {
includedChannels.Add(channel); includedChannels.Add(channel);
includedServerIds.Add(channel.Server); includedServerIds.Add(channel.Server);
} }
} }
var users = GenerateUserList(db, includedUserIds, out var userindex, out var userIndices); var (users, userIndex, userIndices) = await GenerateUserList(db, includedUserIds);
var servers = GenerateServerList(db, includedServerIds, out var serverindex); var (servers, serverIndices) = await GenerateServerList(db, includedServerIds);
var channels = GenerateChannelList(includedChannels, serverindex); var channels = GenerateChannelList(includedChannels, serverIndices);
perf.Step("Collect database data"); perf.Step("Collect database data");
var value = new ViewerJson { var value = new ViewerJson {
Meta = new ViewerJson.JsonMeta { Meta = new ViewerJson.JsonMeta {
Users = users, Users = users,
Userindex = userindex, Userindex = userIndex,
Servers = servers, Servers = servers,
Channels = channels Channels = channels
}, },
@ -60,12 +60,12 @@ public static class ViewerJsonExport {
perf.End(); perf.End();
} }
private static Dictionary<Snowflake, ViewerJson.JsonUser> GenerateUserList(IDatabaseFile db, HashSet<ulong> userIds, out List<Snowflake> userindex, out Dictionary<ulong, int> userIndices) { private static async Task<(Dictionary<Snowflake, ViewerJson.JsonUser> Users, List<Snowflake> UserIndex, Dictionary<ulong, int> UserIndices)> GenerateUserList(IDatabaseFile db, HashSet<ulong> userIds) {
var users = new Dictionary<Snowflake, ViewerJson.JsonUser>(); var users = new Dictionary<Snowflake, ViewerJson.JsonUser>();
userindex = new List<Snowflake>(); var userIndex = new List<Snowflake>();
userIndices = new Dictionary<ulong, int>(); var userIndices = new Dictionary<ulong, int>();
foreach (var user in db.GetAllUsers()) { await foreach (var user in db.Users.Get()) {
var id = user.Id; var id = user.Id;
if (!userIds.Contains(id)) { if (!userIds.Contains(id)) {
continue; continue;
@ -73,7 +73,7 @@ public static class ViewerJsonExport {
var idSnowflake = new Snowflake(id); var idSnowflake = new Snowflake(id);
userIndices[id] = users.Count; userIndices[id] = users.Count;
userindex.Add(idSnowflake); userIndex.Add(idSnowflake);
users[idSnowflake] = new ViewerJson.JsonUser { users[idSnowflake] = new ViewerJson.JsonUser {
Name = user.Name, Name = user.Name,
@ -82,14 +82,14 @@ public static class ViewerJsonExport {
}; };
} }
return users; return (users, userIndex, userIndices);
} }
private static List<ViewerJson.JsonServer> GenerateServerList(IDatabaseFile db, HashSet<ulong> serverIds, out Dictionary<ulong, int> serverIndices) { private static async Task<(List<ViewerJson.JsonServer> Servers, Dictionary<ulong, int> ServerIndices)> GenerateServerList(IDatabaseFile db, HashSet<ulong> serverIds) {
var servers = new List<ViewerJson.JsonServer>(); var servers = new List<ViewerJson.JsonServer>();
serverIndices = new Dictionary<ulong, int>(); var serverIndices = new Dictionary<ulong, int>();
foreach (var server in db.GetAllServers()) { await foreach (var server in db.Servers.Get()) {
var id = server.Id; var id = server.Id;
if (!serverIds.Contains(id)) { if (!serverIds.Contains(id)) {
continue; continue;
@ -103,7 +103,7 @@ public static class ViewerJsonExport {
}); });
} }
return servers; return (servers, serverIndices);
} }
private static Dictionary<Snowflake, ViewerJson.JsonChannel> GenerateChannelList(List<Channel> includedChannels, Dictionary<ulong, int> serverIndices) { private static Dictionary<Snowflake, ViewerJson.JsonChannel> GenerateChannelList(List<Channel> includedChannels, Dictionary<ulong, int> serverIndices) {

View File

@ -1,43 +1,19 @@
using System; using System;
using System.Collections.Generic; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Database.Repositories;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database; namespace DHT.Server.Database;
public interface IDatabaseFile : IDisposable { public interface IDatabaseFile : IDisposable {
string Path { get; } string Path { get; }
DatabaseStatistics Statistics { get; } DatabaseStatistics Statistics { get; }
DatabaseStatisticsSnapshot SnapshotStatistics(); Task<DatabaseStatisticsSnapshot> SnapshotStatistics();
void AddServer(Data.Server server); IUserRepository Users { get; }
List<Data.Server> GetAllServers(); IServerRepository Servers { get; }
IChannelRepository Channels { get; }
IMessageRepository Messages { get; }
IDownloadRepository Downloads { get; }
void AddChannel(Channel channel); Task Vacuum();
List<Channel> GetAllChannels();
void AddUsers(User[] users);
List<User> GetAllUsers();
void AddMessages(Message[] messages);
int CountMessages(MessageFilter? filter = null);
List<Message> GetMessages(MessageFilter? filter = null);
HashSet<ulong> GetMessageIds(MessageFilter? filter = null);
void RemoveMessages(MessageFilter filter, FilterRemovalMode mode);
int CountAttachments(AttachmentFilter? filter = null);
void AddDownload(Data.Download download);
List<Data.Download> GetDownloadsWithoutData();
Data.Download GetDownloadWithData(Data.Download download);
DownloadedAttachment? GetDownloadedAttachment(string url);
void EnqueueDownloadItems(AttachmentFilter? filter = null);
List<DownloadItem> PullEnqueuedDownloadItems(int count);
void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
DownloadStatusStatistics GetDownloadStatusStatistics();
void Vacuum();
} }

View File

@ -33,10 +33,8 @@ public static class LegacyArchiveImport {
var servers = ReadServerList(meta, fakeSnowflake); var servers = ReadServerList(meta, fakeSnowflake);
var newServersOnly = new HashSet<Data.Server>(servers); var newServersOnly = new HashSet<Data.Server>(servers);
var oldServersById = db.GetAllServers().ToDictionary(static server => server.Id, static server => server); var oldServersById = await db.Servers.Get().ToDictionaryAsync(static server => server.Id, static server => server);
var oldChannelsById = await db.Channels.Get().ToDictionaryAsync(static channel => channel.Id, static channel => channel);
var oldChannels = db.GetAllChannels();
var oldChannelsById = oldChannels.ToDictionary(static channel => channel.Id, static channel => channel);
foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) { foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) {
if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) { if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) {
@ -66,17 +64,17 @@ public static class LegacyArchiveImport {
perf.Step("Read channel list"); perf.Step("Read channel list");
var oldMessageIds = db.GetMessageIds(); var oldMessageIds = await db.Messages.GetIds().ToHashSetAsync();
var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake)) var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake))
.Where(message => !oldMessageIds.Contains(message.Id)) .Where(message => !oldMessageIds.Contains(message.Id))
.ToArray(); .ToArray();
perf.Step("Read messages"); perf.Step("Read messages");
db.AddUsers(users); await db.Users.Add(users);
db.AddServers(servers); await db.Servers.Add(servers);
db.AddChannels(channels); await db.Channels.Add(channels);
db.AddMessages(newMessages); await db.Messages.Add(newMessages);
perf.Step("Import into database"); perf.Step("Import into database");
} catch (HttpException e) { } catch (HttpException e) {
@ -179,9 +177,9 @@ public static class LegacyArchiveImport {
Timestamp = messageObj.RequireLong("t", path), Timestamp = messageObj.RequireLong("t", path),
EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null, EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null,
RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null, RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null,
Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableArray() : ImmutableArray<Attachment>.Empty, Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableList() : ImmutableList<Attachment>.Empty,
Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableArray() : ImmutableArray<Embed>.Empty, Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableList() : ImmutableList<Embed>.Empty,
Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableArray() : ImmutableArray<Reaction>.Empty, Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableList() : ImmutableList<Reaction>.Empty,
}; };
}).ToArray(); }).ToArray();
} }

View File

@ -0,0 +1,22 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IChannelRepository {
Task Add(IReadOnlyList<Channel> channels);
IAsyncEnumerable<Channel> Get();
internal sealed class Dummy : IChannelRepository {
public Task Add(IReadOnlyList<Channel> channels) {
return Task.CompletedTask;
}
public IAsyncEnumerable<Channel> Get() {
return AsyncEnumerable.Empty<Channel>();
}
}
}

View File

@ -0,0 +1,68 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database.Repositories;
public interface IDownloadRepository {
Task<long> CountAttachments(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
Task AddDownload(Data.Download download);
Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Download> GetWithoutData();
Task<Data.Download> HydrateWithData(Data.Download download);
Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl);
Task<int> EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default);
Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
internal sealed class Dummy : IDownloadRepository {
public Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public Task AddDownload(Data.Download download) {
return Task.CompletedTask;
}
public Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
return Task.FromResult(new DownloadStatusStatistics());
}
public IAsyncEnumerable<Data.Download> GetWithoutData() {
return AsyncEnumerable.Empty<Data.Download>();
}
public Task<Data.Download> HydrateWithData(Data.Download download) {
return Task.FromResult(download);
}
public Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
return Task.FromResult<DownloadedAttachment?>(null);
}
public Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0);
}
public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) {
return AsyncEnumerable.Empty<DownloadItem>();
}
public Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,42 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IMessageRepository {
Task Add(IReadOnlyList<Message> messages);
Task<long> Count(MessageFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<Message> Get(MessageFilter? filter = null);
IAsyncEnumerable<ulong> GetIds(MessageFilter? filter = null);
Task Remove(MessageFilter filter, FilterRemovalMode mode);
internal sealed class Dummy : IMessageRepository {
public Task Add(IReadOnlyList<Message> messages) {
return Task.CompletedTask;
}
public Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Message> Get(MessageFilter? filter) {
return AsyncEnumerable.Empty<Message>();
}
public IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
return AsyncEnumerable.Empty<ulong>();
}
public Task Remove(MessageFilter filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,21 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace DHT.Server.Database.Repositories;
public interface IServerRepository {
Task Add(IReadOnlyList<Data.Server> servers);
IAsyncEnumerable<Data.Server> Get();
internal sealed class Dummy : IServerRepository {
public Task Add(IReadOnlyList<Data.Server> servers) {
return Task.CompletedTask;
}
public IAsyncEnumerable<Data.Server> Get() {
return AsyncEnumerable.Empty<Data.Server>();
}
}
}

View File

@ -0,0 +1,22 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IUserRepository {
Task Add(IReadOnlyList<User> users);
IAsyncEnumerable<User> Get();
internal sealed class Dummy : IUserRepository {
public Task Add(IReadOnlyList<User> users) {
return Task.CompletedTask;
}
public IAsyncEnumerable<User> Get() {
return AsyncEnumerable.Empty<User>();
}
}
}

View File

@ -0,0 +1,77 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteChannelRepository : IChannelRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteChannelRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateChannelStatistics(conn);
}
private async Task UpdateChannelStatistics(ISqliteConnection conn) {
statistics.TotalChannels = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<Channel> channels) {
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("channels", [
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer)
]);
foreach (var channel in channels) {
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await UpdateChannelStatistics(conn);
}
public async IAsyncEnumerable<Channel> Get() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6),
};
}
}
}

View File

@ -0,0 +1,233 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteDownloadRepository : IDownloadRepository {
private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
public SqliteDownloadRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalDownloadsComputer) {
this.pool = pool;
this.totalDownloadsComputer = totalDownloadsComputer;
}
public async Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async Task AddDownload(Data.Download download) {
using (var conn = pool.Take()) {
await using var cmd = conn.Upsert("downloads", [
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("size", SqliteType.Integer),
("blob", SqliteType.Blob)
]);
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
await cmd.ExecuteNonQueryAsync();
}
totalDownloadsComputer.Recompute();
}
public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
static async Task LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0)
FROM (SELECT MAX(a.size) size
FROM attachments a
WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d)
GROUP BY a.normalized_url)
""");
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static async Task LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
using var conn = pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken);
return result;
}
public async IAsyncEnumerable<Data.Download> GetWithoutData() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
yield return new Data.Download(normalizedUrl, downloadUrl, status, size);
}
}
public async Task<Data.Download> HydrateWithData(Data.Download download) {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
await using var reader = await cmd.ExecuteReaderAsync();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
}
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take();
await using var cmd = conn.Command(
"""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
"""
);
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync();
if (!reader.Read()) {
return null;
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public async Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
await using var cmd = conn.Command(
$"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>();
using var conn = pool.Take();
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
while (reader.Read()) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
});
}
}
if (found.Count != 0) {
await using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
foreach (var item in found) {
cmd.Set(":normalized_url", item.NormalizedUrl);
if (await cmd.ExecuteNonQueryAsync(cancellationToken) == 1) {
yield return item;
}
}
}
}
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM downloads
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
totalDownloadsComputer.Recompute();
}
}

View File

@ -0,0 +1,306 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteMessageRepository : IMessageRepository {
private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
public SqliteMessageRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalMessagesComputer, AsyncValueComputer<long>.Single totalAttachmentsComputer) {
this.pool = pool;
this.totalMessagesComputer = totalMessagesComputer;
this.totalAttachmentsComputer = totalAttachmentsComputer;
}
public async Task Add(IReadOnlyList<Message> messages) {
if (messages.Count == 0) {
return;
}
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
static async Task ExecuteDeleteByMessageId(SqliteCommand cmd, object id) {
cmd.Set(":message_id", id);
await cmd.ExecuteNonQueryAsync();
}
bool addedAttachments = false;
using (var conn = pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer)
]);
await using var deleteEditTimestampCmd = DeleteByMessageId(conn, "edit_timestamps");
await using var deleteRepliedToCmd = DeleteByMessageId(conn, "replied_to");
await using var deleteAttachmentsCmd = DeleteByMessageId(conn, "attachments");
await using var deleteEmbedsCmd = DeleteByMessageId(conn, "embeds");
await using var deleteReactionsCmd = DeleteByMessageId(conn, "reactions");
await using var editTimestampCmd = conn.Insert("edit_timestamps", [
("message_id", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer)
]);
await using var repliedToCmd = conn.Insert("replied_to", [
("message_id", SqliteType.Integer),
("replied_to_id", SqliteType.Integer)
]);
await using var attachmentCmd = conn.Insert("attachments", [
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("size", SqliteType.Integer),
("width", SqliteType.Integer),
("height", SqliteType.Integer)
]);
await using var embedCmd = conn.Insert("embeds", [
("message_id", SqliteType.Integer),
("json", SqliteType.Text)
]);
await using var reactionCmd = conn.Insert("reactions", [
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
]);
foreach (var message in messages) {
object messageId = message.Id;
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
await messageCmd.ExecuteNonQueryAsync();
await ExecuteDeleteByMessageId(deleteEditTimestampCmd, messageId);
await ExecuteDeleteByMessageId(deleteRepliedToCmd, messageId);
await ExecuteDeleteByMessageId(deleteAttachmentsCmd, messageId);
await ExecuteDeleteByMessageId(deleteEmbedsCmd, messageId);
await ExecuteDeleteByMessageId(deleteReactionsCmd, messageId);
if (message.EditTimestamp is {} timestamp) {
editTimestampCmd.Set(":message_id", messageId);
editTimestampCmd.Set(":edit_timestamp", timestamp);
await editTimestampCmd.ExecuteNonQueryAsync();
}
if (message.RepliedToId is {} repliedToId) {
repliedToCmd.Set(":message_id", messageId);
repliedToCmd.Set(":replied_to_id", repliedToId);
await repliedToCmd.ExecuteNonQueryAsync();
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":normalized_url", attachment.NormalizedUrl);
attachmentCmd.Set(":download_url", attachment.DownloadUrl);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
await attachmentCmd.ExecuteNonQueryAsync();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
await embedCmd.ExecuteNonQueryAsync();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
await reactionCmd.ExecuteNonQueryAsync();
}
}
}
await tx.CommitAsync();
}
totalMessagesComputer.Recompute();
if (addedAttachments) {
totalAttachmentsComputer.Recompute();
}
}
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
private sealed class MesageToManyCommand<T> : IAsyncDisposable {
private readonly SqliteCommand cmd;
private readonly Func<SqliteDataReader, T> readItem;
public MesageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
this.cmd = conn.Command(sql);
this.cmd.Add(":message_id", SqliteType.Integer);
this.readItem = readItem;
}
public async Task<ImmutableList<T>> GetItems(ulong messageId) {
cmd.Set(":message_id", messageId);
var items = ImmutableList<T>.Empty;
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
items = items.Add(readItem(reader));
}
return items;
}
public async ValueTask DisposeAsync() {
await cmd.DisposeAsync();
}
}
public async IAsyncEnumerable<Message> Get(MessageFilter? filter) {
using var conn = pool.Take();
const string AttachmentSql =
"""
SELECT attachment_id, name, type, normalized_url, download_url, size, width, height
FROM attachments
WHERE message_id = :message_id
""";
await using var attachmentCmd = new MesageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
NormalizedUrl = reader.GetString(3),
DownloadUrl = reader.GetString(4),
Size = reader.GetUint64(5),
Width = reader.IsDBNull(6) ? null : reader.GetInt32(6),
Height = reader.IsDBNull(7) ? null : reader.GetInt32(7),
});
const string EmbedSql =
"""
SELECT json
FROM embeds
WHERE message_id = :message_id
""";
await using var embedCmd = new MesageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
Json = reader.GetString(0)
});
const string ReactionSql =
"""
SELECT emoji_id, emoji_name, emoji_flags, count
FROM reactions
WHERE message_id = :message_id
""";
await using var reactionsCmd = new MesageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
EmojiId = reader.IsDBNull(0) ? null : reader.GetUint64(0),
EmojiName = reader.IsDBNull(1) ? null : reader.GetString(1),
EmojiFlags = (EmojiFlags) reader.GetInt16(2),
Count = reader.GetInt32(3),
});
await using var messageCmd = conn.Command(
$"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
"""
);
await using var reader = await messageCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
ulong messageId = reader.GetUint64(0);
yield return new Message {
Id = messageId,
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = await attachmentCmd.GetItems(messageId),
Embeds = await embedCmd.GetItems(messageId),
Reactions = await reactionsCmd.GetItems(messageId)
};
}
}
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
yield return reader.GetUint64(0);
}
}
public async Task Remove(MessageFilter filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM messages
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
totalMessagesComputer.Recompute();
}
}

View File

@ -0,0 +1,65 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteServerRepository : IServerRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteServerRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateServerStatistics(conn);
}
private async Task UpdateServerStatistics(ISqliteConnection conn) {
statistics.TotalServers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<Data.Server> servers) {
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("servers", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text)
]);
foreach (var server in servers) {
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await UpdateServerStatistics(conn);
}
public async IAsyncEnumerable<Data.Server> Get() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)),
};
}
}
}

View File

@ -0,0 +1,68 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteUserRepository : IUserRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteUserRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateUserStatistics(conn);
}
private async Task UpdateUserStatistics(ISqliteConnection conn) {
statistics.TotalUsers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<User> users) {
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text)
]);
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await UpdateUserStatistics(conn);
}
public async IAsyncEnumerable<User> Get() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3),
};
}
}
}

View File

@ -1,4 +1,5 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Data.Common;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Database.Exceptions; using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
@ -20,14 +21,14 @@ sealed class Schema {
} }
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) { public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
conn.Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)"); await conn.ExecuteAsync("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'"); var dbVersionStr = await conn.ExecuteReaderAsync("SELECT value FROM metadata WHERE key = 'version'", static reader => reader?.GetString(0));
if (dbVersionStr == null) { if (dbVersionStr == null) {
InitializeSchemas(); await InitializeSchemas();
} }
else if (!int.TryParse(dbVersionStr.ToString(), out int dbVersion) || dbVersion < 1) { else if (!int.TryParse(dbVersionStr, out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr.ToString() ?? "<null>"); throw new InvalidDatabaseVersionException(dbVersionStr);
} }
else if (dbVersion > Version) { else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion); throw new DatabaseTooNewException(dbVersion);
@ -44,123 +45,123 @@ sealed class Schema {
return true; return true;
} }
private void InitializeSchemas() { private async Task InitializeSchemas() {
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE users ( CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL, id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
avatar_url TEXT, avatar_url TEXT,
discriminator TEXT discriminator TEXT
) )
"""); """);
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE servers ( CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL, id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
type TEXT NOT NULL type TEXT NOT NULL
) )
"""); """);
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE channels ( CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL, id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL, server INTEGER NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
parent_id INTEGER, parent_id INTEGER,
position INTEGER, position INTEGER,
topic TEXT, topic TEXT,
nsfw INTEGER nsfw INTEGER
) )
"""); """);
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE messages ( CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL, message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL, sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL, channel_id INTEGER NOT NULL,
text TEXT NOT NULL, text TEXT NOT NULL,
timestamp INTEGER NOT NULL timestamp INTEGER NOT NULL
) )
"""); """);
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE attachments ( CREATE TABLE attachments (
message_id INTEGER NOT NULL, message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL, attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
type TEXT, type TEXT,
normalized_url TEXT NOT NULL, normalized_url TEXT NOT NULL,
download_url TEXT, download_url TEXT,
size INTEGER NOT NULL, size INTEGER NOT NULL,
width INTEGER, width INTEGER,
height INTEGER height INTEGER
) )
"""); """);
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE embeds ( CREATE TABLE embeds (
message_id INTEGER NOT NULL, message_id INTEGER NOT NULL,
json TEXT NOT NULL json TEXT NOT NULL
) )
"""); """);
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE downloads ( CREATE TABLE downloads (
normalized_url TEXT NOT NULL PRIMARY KEY, normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT, download_url TEXT,
status INTEGER NOT NULL, status INTEGER NOT NULL,
size INTEGER NOT NULL, size INTEGER NOT NULL,
blob BLOB blob BLOB
) )
"""); """);
conn.Execute("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
CreateMessageEditTimestampTable(); await conn.ExecuteAsync("""
CreateMessageRepliedToTable(); CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
conn.Execute("CREATE INDEX attachments_message_ix ON attachments(message_id)"); await CreateMessageEditTimestampTable();
conn.Execute("CREATE INDEX embeds_message_ix ON embeds(message_id)"); await CreateMessageRepliedToTable();
conn.Execute("CREATE INDEX reactions_message_ix ON reactions(message_id)");
conn.Execute("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")"); await conn.ExecuteAsync("CREATE INDEX attachments_message_ix ON attachments(message_id)");
await conn.ExecuteAsync("CREATE INDEX embeds_message_ix ON embeds(message_id)");
await conn.ExecuteAsync("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await conn.ExecuteAsync("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
} }
private void CreateMessageEditTimestampTable() { private async Task CreateMessageEditTimestampTable() {
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE edit_timestamps ( CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL, message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL edit_timestamp INTEGER NOT NULL
) )
"""); """);
} }
private void CreateMessageRepliedToTable() { private async Task CreateMessageRepliedToTable() {
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE replied_to ( CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL, message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL replied_to_id INTEGER NOT NULL
) )
"""); """);
} }
private async Task NormalizeAttachmentUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) { private async Task NormalizeAttachmentUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0); await reporter.SubWork("Preparing attachments...", 0, 0);
var normalizedUrls = new Dictionary<long, string>(); var normalizedUrls = new Dictionary<long, string>();
await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) { await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
await using var reader = await selectCmd.ExecuteReaderAsync(); await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) { while (reader.Read()) {
var attachmentId = reader.GetInt64(0); var attachmentId = reader.GetInt64(0);
var originalUrl = reader.GetString(1); var originalUrl = reader.GetString(1);
@ -168,7 +169,7 @@ sealed class Schema {
} }
} }
await using var tx = conn.BeginTransaction(); await using var tx = await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count; int totalUrls = normalizedUrls.Count;
int processedUrls = -1; int processedUrls = -1;
@ -176,7 +177,7 @@ sealed class Schema {
await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) { await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Add(":attachment_id", SqliteType.Integer); updateCmd.Add(":attachment_id", SqliteType.Integer);
updateCmd.Add(":normalized_url", SqliteType.Text); updateCmd.Add(":normalized_url", SqliteType.Text);
foreach (var (attachmentId, normalizedUrl) in normalizedUrls) { foreach (var (attachmentId, normalizedUrl) in normalizedUrls) {
if (++processedUrls % 1000 == 0) { if (++processedUrls % 1000 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls); await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
@ -187,15 +188,15 @@ sealed class Schema {
updateCmd.ExecuteNonQuery(); updateCmd.ExecuteNonQuery();
} }
} }
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls); await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync(); await tx.CommitAsync();
} }
private async Task NormalizeDownloadUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) { private async Task NormalizeDownloadUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing downloads...", 0, 0); await reporter.SubWork("Preparing downloads...", 0, 0);
var normalizedUrlsToOriginalUrls = new Dictionary<string, string>(); var normalizedUrlsToOriginalUrls = new Dictionary<string, string>();
var duplicateUrlsToDelete = new HashSet<string>(); var duplicateUrlsToDelete = new HashSet<string>();
@ -212,11 +213,11 @@ sealed class Schema {
} }
} }
conn.Execute("PRAGMA cache_size = -20000"); await conn.ExecuteAsync("PRAGMA cache_size = -20000");
SqliteTransaction tx; DbTransaction tx;
await using (tx = conn.BeginTransaction()) { await using (tx = await conn.BeginTransactionAsync()) {
await reporter.SubWork("Deleting duplicates...", 0, 0); await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) { await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
@ -225,30 +226,30 @@ sealed class Schema {
deleteCmd.ExecuteNonQuery(); deleteCmd.ExecuteNonQuery();
} }
} }
await tx.CommitAsync(); await tx.CommitAsync();
} }
int totalUrls = normalizedUrlsToOriginalUrls.Count; int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1; int processedUrls = -1;
tx = conn.BeginTransaction(); tx = await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) { await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text); updateCmd.Add(":normalized_url", SqliteType.Text);
updateCmd.Add(":download_url", SqliteType.Text); updateCmd.Add(":download_url", SqliteType.Text);
foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) { foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) {
if (++processedUrls % 100 == 0) { if (++processedUrls % 100 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls); await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
// Not proper way of dealing with transactions, but it avoids a long commit at the end. // Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse. // Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await tx.CommitAsync(); await tx.CommitAsync();
await tx.DisposeAsync(); await tx.DisposeAsync();
tx = conn.BeginTransaction(); tx = await conn.BeginTransactionAsync();
updateCmd.Transaction = tx; updateCmd.Transaction = (SqliteTransaction) tx;
} }
updateCmd.Set(":normalized_url", normalizedUrl); updateCmd.Set(":normalized_url", normalizedUrl);
@ -256,98 +257,98 @@ sealed class Schema {
updateCmd.ExecuteNonQuery(); updateCmd.ExecuteNonQuery();
} }
} }
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls); await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync(); await tx.CommitAsync();
await tx.DisposeAsync(); await tx.DisposeAsync();
conn.Execute("PRAGMA cache_size = -2000"); await conn.ExecuteAsync("PRAGMA cache_size = -2000");
} }
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) { private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var perf = Log.Start("from version " + dbVersion); var perf = Log.Start("from version " + dbVersion);
conn.Execute("UPDATE metadata SET value = " + Version + " WHERE key = 'version'"); await conn.ExecuteAsync("UPDATE metadata SET value = " + Version + " WHERE key = 'version'");
if (dbVersion <= 1) { if (dbVersion <= 1) {
await reporter.MainWork("Applying schema changes...", 0, 1); await reporter.MainWork("Applying schema changes...", 0, 1);
conn.Execute("ALTER TABLE channels ADD parent_id INTEGER"); await conn.ExecuteAsync("ALTER TABLE channels ADD parent_id INTEGER");
perf.Step("Upgrade to version 2"); perf.Step("Upgrade to version 2");
await reporter.NextVersion(); await reporter.NextVersion();
} }
if (dbVersion <= 2) { if (dbVersion <= 2) {
await reporter.MainWork("Applying schema changes...", 0, 1); await reporter.MainWork("Applying schema changes...", 0, 1);
CreateMessageEditTimestampTable();
CreateMessageRepliedToTable();
conn.Execute(""" await CreateMessageEditTimestampTable();
INSERT INTO edit_timestamps (message_id, edit_timestamp) await CreateMessageRepliedToTable();
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
conn.Execute(""" await conn.ExecuteAsync("""
INSERT INTO replied_to (message_id, replied_to_id) INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, replied_to_id SELECT message_id, edit_timestamp
FROM messages FROM messages
WHERE replied_to_id IS NOT NULL WHERE edit_timestamp IS NOT NULL
"""); """);
conn.Execute("ALTER TABLE messages DROP COLUMN replied_to_id"); await conn.ExecuteAsync("""
conn.Execute("ALTER TABLE messages DROP COLUMN edit_timestamp"); INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN replied_to_id");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN edit_timestamp");
perf.Step("Upgrade to version 3"); perf.Step("Upgrade to version 3");
await reporter.MainWork("Vacuuming the database...", 1, 1); await reporter.MainWork("Vacuuming the database...", 1, 1);
conn.Execute("VACUUM"); await conn.ExecuteAsync("VACUUM");
perf.Step("Vacuum"); perf.Step("Vacuum");
await reporter.NextVersion(); await reporter.NextVersion();
} }
if (dbVersion <= 3) { if (dbVersion <= 3) {
conn.Execute(""" await conn.ExecuteAsync("""
CREATE TABLE downloads ( CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY, url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL, status INTEGER NOT NULL,
size INTEGER NOT NULL, size INTEGER NOT NULL,
blob BLOB blob BLOB
) )
"""); """);
perf.Step("Upgrade to version 4"); perf.Step("Upgrade to version 4");
await reporter.NextVersion(); await reporter.NextVersion();
} }
if (dbVersion <= 4) { if (dbVersion <= 4) {
await reporter.MainWork("Applying schema changes...", 0, 1); await reporter.MainWork("Applying schema changes...", 0, 1);
conn.Execute("ALTER TABLE attachments ADD width INTEGER"); await conn.ExecuteAsync("ALTER TABLE attachments ADD width INTEGER");
conn.Execute("ALTER TABLE attachments ADD height INTEGER"); await conn.ExecuteAsync("ALTER TABLE attachments ADD height INTEGER");
perf.Step("Upgrade to version 5"); perf.Step("Upgrade to version 5");
await reporter.NextVersion(); await reporter.NextVersion();
} }
if (dbVersion <= 5) { if (dbVersion <= 5) {
await reporter.MainWork("Applying schema changes...", 0, 3); await reporter.MainWork("Applying schema changes...", 0, 3);
conn.Execute("ALTER TABLE attachments ADD download_url TEXT"); await conn.ExecuteAsync("ALTER TABLE attachments ADD download_url TEXT");
conn.Execute("ALTER TABLE downloads ADD download_url TEXT"); await conn.ExecuteAsync("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3); await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(reporter); await NormalizeAttachmentUrls(reporter);
await reporter.MainWork("Updating downloads...", 2, 3); await reporter.MainWork("Updating downloads...", 2, 3);
await NormalizeDownloadUrls(reporter); await NormalizeDownloadUrls(reporter);
await reporter.MainWork("Applying schema changes...", 3, 3); await reporter.MainWork("Applying schema changes...", 3, 3);
conn.Execute("ALTER TABLE attachments RENAME COLUMN url TO normalized_url"); await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
conn.Execute("ALTER TABLE downloads RENAME COLUMN url TO normalized_url"); await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
perf.Step("Upgrade to version 6"); perf.Step("Upgrade to version 6");
await reporter.NextVersion(); await reporter.NextVersion();
} }

View File

@ -1,15 +1,8 @@
using System; using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Database.Repositories;
using DHT.Server.Data.Aggregations; using DHT.Server.Database.Sqlite.Repositories;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Collections;
using DHT.Utils.Logging;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
@ -36,7 +29,9 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
} }
if (wasOpened) { if (wasOpened) {
return new SqliteDatabaseFile(path, pool, computeTaskResultScheduler); var db = new SqliteDatabaseFile(path, pool, computeTaskResultScheduler);
await db.Initialize();
return db;
} }
else { else {
pool.Dispose(); pool.Dispose();
@ -46,15 +41,26 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
public string Path { get; } public string Path { get; }
public DatabaseStatistics Statistics { get; } public DatabaseStatistics Statistics { get; }
private readonly Log log; public IUserRepository Users => users;
public IServerRepository Servers => servers;
public IChannelRepository Channels => channels;
public IMessageRepository Messages => messages;
public IDownloadRepository Downloads => downloads;
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly SqliteUserRepository users;
private readonly SqliteServerRepository servers;
private readonly SqliteChannelRepository channels;
private readonly SqliteMessageRepository messages;
private readonly SqliteDownloadRepository downloads;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer; private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer; private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer; private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool, TaskScheduler computeTaskResultScheduler) { private SqliteDatabaseFile(string path, SqliteConnectionPool pool, TaskScheduler computeTaskResultScheduler) {
this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path));
this.pool = pool; this.pool = pool;
this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics); this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics);
@ -64,669 +70,65 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
this.Path = path; this.Path = path;
this.Statistics = new DatabaseStatistics(); this.Statistics = new DatabaseStatistics();
using (var conn = pool.Take()) { this.users = new SqliteUserRepository(pool, Statistics);
UpdateServerStatistics(conn); this.servers = new SqliteServerRepository(pool, Statistics);
UpdateChannelStatistics(conn); this.channels = new SqliteChannelRepository(pool, Statistics);
UpdateUserStatistics(conn); this.messages = new SqliteMessageRepository(pool, totalMessagesComputer, totalAttachmentsComputer);
} this.downloads = new SqliteDownloadRepository(pool, totalDownloadsComputer);
totalMessagesComputer.Recompute(); totalMessagesComputer.Recompute();
totalAttachmentsComputer.Recompute(); totalAttachmentsComputer.Recompute();
totalDownloadsComputer.Recompute(); totalDownloadsComputer.Recompute();
} }
private async Task Initialize() {
await users.Initialize();
await servers.Initialize();
await channels.Initialize();
}
public void Dispose() { public void Dispose() {
totalMessagesComputer.Cancel();
totalAttachmentsComputer.Cancel();
totalDownloadsComputer.Cancel();
pool.Dispose(); pool.Dispose();
} }
public DatabaseStatisticsSnapshot SnapshotStatistics() { public async Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return new DatabaseStatisticsSnapshot { return new DatabaseStatisticsSnapshot {
TotalServers = Statistics.TotalServers, TotalServers = Statistics.TotalServers,
TotalChannels = Statistics.TotalChannels, TotalChannels = Statistics.TotalChannels,
TotalUsers = Statistics.TotalUsers, TotalUsers = Statistics.TotalUsers,
TotalMessages = ComputeMessageStatistics(), TotalMessages = await ComputeMessageStatistics(),
}; };
} }
public void AddServer(Data.Server server) { public async Task Vacuum() {
using var conn = pool.Take(); using var conn = pool.Take();
using var cmd = conn.Upsert("servers", new[] { await conn.ExecuteAsync("VACUUM");
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
});
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery();
UpdateServerStatistics(conn);
} }
public List<Data.Server> GetAllServers() { private async Task<long> ComputeMessageStatistics() {
var perf = log.Start();
var list = new List<Data.Server>();
using var conn = pool.Take(); using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, type FROM servers"); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages", static reader => reader?.GetInt64(0) ?? 0L);
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)),
});
}
perf.End();
return list;
}
public void AddChannel(Channel channel) {
using var conn = pool.Take();
using var cmd = conn.Upsert("channels", new[] {
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer),
});
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery();
UpdateChannelStatistics(conn);
}
public List<Channel> GetAllChannels() {
var list = new List<Channel>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6),
});
}
return list;
}
public void AddUsers(User[] users) {
using var conn = pool.Take();
using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] {
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text),
});
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
cmd.ExecuteNonQuery();
}
tx.Commit();
UpdateUserStatistics(conn);
}
public List<User> GetAllUsers() {
var perf = log.Start();
var list = new List<User>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3),
});
}
perf.End();
return list;
}
public void AddMessages(Message[] messages) {
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
static void ExecuteDeleteByMessageId(SqliteCommand cmd, object id) {
cmd.Set(":message_id", id);
cmd.ExecuteNonQuery();
}
bool addedAttachments = false;
using (var conn = pool.Take()) {
using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] {
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer),
});
using var deleteEditTimestampCmd = DeleteByMessageId(conn, "edit_timestamps");
using var deleteRepliedToCmd = DeleteByMessageId(conn, "replied_to");
using var deleteAttachmentsCmd = DeleteByMessageId(conn, "attachments");
using var deleteEmbedsCmd = DeleteByMessageId(conn, "embeds");
using var deleteReactionsCmd = DeleteByMessageId(conn, "reactions");
using var editTimestampCmd = conn.Insert("edit_timestamps", new [] {
("message_id", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer),
});
using var repliedToCmd = conn.Insert("replied_to", new [] {
("message_id", SqliteType.Integer),
("replied_to_id", SqliteType.Integer),
});
using var attachmentCmd = conn.Insert("attachments", new[] {
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("size", SqliteType.Integer),
("width", SqliteType.Integer),
("height", SqliteType.Integer),
});
using var embedCmd = conn.Insert("embeds", new[] {
("message_id", SqliteType.Integer),
("json", SqliteType.Text),
});
using var reactionCmd = conn.Insert("reactions", new[] {
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer),
});
foreach (var message in messages) {
object messageId = message.Id;
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
messageCmd.ExecuteNonQuery();
ExecuteDeleteByMessageId(deleteEditTimestampCmd, messageId);
ExecuteDeleteByMessageId(deleteRepliedToCmd, messageId);
ExecuteDeleteByMessageId(deleteAttachmentsCmd, messageId);
ExecuteDeleteByMessageId(deleteEmbedsCmd, messageId);
ExecuteDeleteByMessageId(deleteReactionsCmd, messageId);
if (message.EditTimestamp is {} timestamp) {
editTimestampCmd.Set(":message_id", messageId);
editTimestampCmd.Set(":edit_timestamp", timestamp);
editTimestampCmd.ExecuteNonQuery();
}
if (message.RepliedToId is {} repliedToId) {
repliedToCmd.Set(":message_id", messageId);
repliedToCmd.Set(":replied_to_id", repliedToId);
repliedToCmd.ExecuteNonQuery();
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":normalized_url", attachment.NormalizedUrl);
attachmentCmd.Set(":download_url", attachment.DownloadUrl);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
attachmentCmd.ExecuteNonQuery();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
embedCmd.ExecuteNonQuery();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
reactionCmd.ExecuteNonQuery();
}
}
}
tx.Commit();
}
totalMessagesComputer.Recompute();
if (addedAttachments) {
totalAttachmentsComputer.Recompute();
}
}
public int CountMessages(MessageFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
return reader.Read() ? reader.GetInt32(0) : 0;
}
public List<Message> GetMessages(MessageFilter? filter = null) {
var perf = log.Start();
var list = new List<Message>();
var attachments = GetAllAttachments();
var embeds = GetAllEmbeds();
var reactions = GetAllReactions();
using var conn = pool.Take();
using var cmd = conn.Command($"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
""");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong id = reader.GetUint64(0);
list.Add(new Message {
Id = id,
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Attachment>.Empty,
Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Embed>.Empty,
Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Reaction>.Empty,
});
}
perf.End();
return list;
}
public HashSet<ulong> GetMessageIds(MessageFilter? filter = null) {
var perf = log.Start();
var ids = new HashSet<ulong>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ids.Add(reader.GetUint64(0));
}
perf.End();
return ids;
}
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {
var perf = log.Start();
DeleteFromTable("messages", filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching));
totalMessagesComputer.Recompute();
perf.End();
}
public int CountAttachments(AttachmentFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"));
using var reader = cmd.ExecuteReader();
return reader.Read() ? reader.GetInt32(0) : 0;
}
public void AddDownload(Data.Download download) {
using var conn = pool.Take();
using var cmd = conn.Upsert("downloads", new[] {
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("size", SqliteType.Integer),
("blob", SqliteType.Blob),
});
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
cmd.ExecuteNonQuery();
totalDownloadsComputer.Recompute();
}
public List<Data.Download> GetDownloadsWithoutData() {
var list = new List<Data.Download>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
list.Add(new Data.Download(normalizedUrl, downloadUrl, status, size));
}
return list;
}
public Data.Download GetDownloadWithData(Data.Download download) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
using var reader = cmd.ExecuteReader();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
}
public DownloadedAttachment? GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take();
using var cmd = conn.Command("""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
""");
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
using var reader = cmd.ExecuteReader();
if (!reader.Read()) {
return null;
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public void EnqueueDownloadItems(AttachmentFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command($"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
""");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.ExecuteNonQuery();
}
public List<DownloadItem> PullEnqueuedDownloadItems(int count) {
var found = new List<DownloadItem>();
var pulled = new List<DownloadItem>();
using var conn = pool.Take();
using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
});
}
}
if (found.Count != 0) {
using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
foreach (var item in found) {
cmd.Set(":normalized_url", item.NormalizedUrl);
if (cmd.ExecuteNonQuery() == 1) {
pulled.Add(item);
}
}
}
return pulled;
}
public void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
DeleteFromTable("downloads", filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching));
totalDownloadsComputer.Recompute();
}
public DownloadStatusStatistics GetDownloadStatusStatistics() {
static void LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result) {
using var cmd = conn.Command("SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0) FROM (SELECT MAX(a.size) size FROM attachments a WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d) GROUP BY a.normalized_url)");
using var reader = cmd.ExecuteReader();
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static void LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result) {
using var cmd = conn.Command("""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
""");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
using var reader = cmd.ExecuteReader();
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
using var conn = pool.Take();
LoadUndownloadedStatistics(conn, result);
LoadSuccessStatistics(conn, result);
return result;
}
private MultiDictionary<ulong, Attachment> GetAllAttachments() {
var dict = new MultiDictionary<ulong, Attachment>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, normalized_url, download_url, size, width, height FROM attachments");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Attachment {
Id = reader.GetUint64(1),
Name = reader.GetString(2),
Type = reader.IsDBNull(3) ? null : reader.GetString(3),
NormalizedUrl = reader.GetString(4),
DownloadUrl = reader.GetString(5),
Size = reader.GetUint64(6),
Width = reader.IsDBNull(7) ? null : reader.GetInt32(7),
Height = reader.IsDBNull(8) ? null : reader.GetInt32(8),
});
}
return dict;
}
private MultiDictionary<ulong, Embed> GetAllEmbeds() {
var dict = new MultiDictionary<ulong, Embed>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, json FROM embeds");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Embed {
Json = reader.GetString(1),
});
}
return dict;
}
private MultiDictionary<ulong, Reaction> GetAllReactions() {
var dict = new MultiDictionary<ulong, Reaction>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Reaction {
EmojiId = reader.IsDBNull(1) ? null : reader.GetUint64(1),
EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2),
EmojiFlags = (EmojiFlags) reader.GetInt16(3),
Count = reader.GetInt32(4),
});
}
return dict;
}
private void DeleteFromTable(string table, string whereClause) {
// Rider is being stupid...
StringBuilder build = new StringBuilder()
.Append("DELETE ")
.Append("FROM ")
.Append(table)
.Append(whereClause);
using var conn = pool.Take();
using var cmd = conn.Command(build.ToString());
cmd.ExecuteNonQuery();
}
public void Vacuum() {
using var conn = pool.Take();
using var cmd = conn.Command("VACUUM");
cmd.ExecuteNonQuery();
}
private void UpdateServerStatistics(ISqliteConnection conn) {
Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
}
private void UpdateChannelStatistics(ISqliteConnection conn) {
Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
}
private void UpdateUserStatistics(ISqliteConnection conn) {
Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
}
private long ComputeMessageStatistics() {
using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
} }
private void UpdateMessageStatistics(long totalMessages) { private void UpdateMessageStatistics(long totalMessages) {
Statistics.TotalMessages = totalMessages; Statistics.TotalMessages = totalMessages;
} }
private long ComputeAttachmentStatistics() { private async Task<long> ComputeAttachmentStatistics() {
using var conn = pool.Take(); using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(DISTINCT normalized_url) FROM attachments") as long? ?? 0L; return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments", static reader => reader?.GetInt64(0) ?? 0L);
} }
private void UpdateAttachmentStatistics(long totalAttachments) { private void UpdateAttachmentStatistics(long totalAttachments) {
Statistics.TotalAttachments = totalAttachments; Statistics.TotalAttachments = totalAttachments;
} }
private long ComputeDownloadStatistics() { private async Task<long> ComputeDownloadStatistics() {
using var conn = pool.Take(); using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(*) FROM downloads") as long? ?? 0L; return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L);
} }
private void UpdateDownloadStatistics(long totalDownloads) { private void UpdateDownloadStatistics(long totalDownloads) {

View File

@ -8,9 +8,13 @@ using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite; namespace DHT.Server.Database.Sqlite;
static class SqliteFilters { static class SqliteFilters {
private static string WhereAll(bool invert) {
return invert ? "WHERE FALSE" : "";
}
public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) { public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) { if (filter == null || filter.IsEmpty) {
return ""; return WhereAll(invert);
} }
var where = new SqliteWhereGenerator(tableAlias, invert); var where = new SqliteWhereGenerator(tableAlias, invert);
@ -39,8 +43,8 @@ static class SqliteFilters {
} }
public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) { public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) { if (filter == null || filter.IsEmpty) {
return ""; return WhereAll(invert);
} }
var where = new SqliteWhereGenerator(tableAlias, invert); var where = new SqliteWhereGenerator(tableAlias, invert);
@ -60,8 +64,8 @@ static class SqliteFilters {
} }
public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) { public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) { if (filter == null || filter.IsEmpty) {
return ""; return WhereAll(invert);
} }
var where = new SqliteWhereGenerator(tableAlias, invert); var where = new SqliteWhereGenerator(tableAlias, invert);

View File

@ -1,7 +1,7 @@
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace DHT.Server.Database.Sqlite; namespace DHT.Server.Database.Sqlite.Utils;
public interface ISchemaUpgradeCallbacks { public interface ISchemaUpgradeCallbacks {
Task<bool> CanUpgrade(); Task<bool> CanUpgrade();

View File

@ -1,12 +1,16 @@
using System; using System;
using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils; namespace DHT.Server.Database.Sqlite.Utils;
static class SqliteExtensions { static class SqliteExtensions {
public static SqliteTransaction BeginTransaction(this ISqliteConnection conn) { public static ValueTask<DbTransaction> BeginTransactionAsync(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransaction(); return conn.InnerConnection.BeginTransactionAsync();
} }
public static SqliteCommand Command(this ISqliteConnection conn, string sql) { public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
@ -15,14 +19,16 @@ static class SqliteExtensions {
return cmd; return cmd;
} }
public static void Execute(this ISqliteConnection conn, string sql) { public static async Task<int> ExecuteAsync(this ISqliteConnection conn, [LanguageInjection("sql")] string sql, CancellationToken cancellationToken = default) {
using var cmd = conn.Command(sql); await using var cmd = conn.Command(sql);
cmd.ExecuteNonQuery(); return await cmd.ExecuteNonQueryAsync(cancellationToken);
} }
public static async Task<T> ExecuteReaderAsync<T>(this ISqliteConnection conn, string sql, Func<SqliteDataReader?, T> readFunction, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
public static object? SelectScalar(this ISqliteConnection conn, string sql) { return reader.Read() ? readFunction(reader) : readFunction(null);
using var cmd = conn.Command(sql);
return cmd.ExecuteScalar();
} }
public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
@ -52,7 +58,7 @@ static class SqliteExtensions {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column }); CreateParameters(cmd, new[] { column });
return cmd; return cmd;
} }

View File

@ -30,10 +30,10 @@ sealed class DownloaderTask : IAsyncDisposable {
private readonly IDatabaseFile db; private readonly IDatabaseFile db;
private readonly Subject<DownloadItem> finishedItemPublisher = new (); private readonly Subject<DownloadItem> finishedItemPublisher = new ();
private readonly Task queueWriterTask; private readonly Task queueWriterTask;
private readonly Task[] downloadTasks; private readonly Task[] downloadTasks;
public IObservable<DownloadItem> FinishedItems => finishedItemPublisher; public IObservable<DownloadItem> FinishedItems => finishedItemPublisher;
internal DownloaderTask(IDatabaseFile db) { internal DownloaderTask(IDatabaseFile db) {
@ -45,7 +45,7 @@ sealed class DownloaderTask : IAsyncDisposable {
private async Task RunQueueWriterTask() { private async Task RunQueueWriterTask() {
while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) { while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) {
var newItems = db.PullEnqueuedDownloadItems(QueueSize); var newItems = await db.Downloads.PullEnqueuedDownloadItems(QueueSize, cancellationToken).ToListAsync(cancellationToken);
if (newItems.Count == 0) { if (newItems.Count == 0) {
await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken); await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken);
continue; continue;
@ -70,14 +70,14 @@ sealed class DownloaderTask : IAsyncDisposable {
try { try {
var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken); var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken);
db.AddDownload(Data.Download.NewSuccess(item, downloadedBytes)); await db.Downloads.AddDownload(Data.Download.NewSuccess(item, downloadedBytes));
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
// Ignore. // Ignore.
} catch (HttpRequestException e) { } catch (HttpRequestException e) {
db.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size)); await db.Downloads.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size));
log.Error(e); log.Error(e);
} catch (Exception e) { } catch (Exception e) {
db.AddDownload(Data.Download.NewFailure(item, null, item.Size)); await db.Downloads.AddDownload(Data.Download.NewFailure(item, null, item.Size));
log.Error(e); log.Error(e);
} finally { } finally {
try { try {

View File

@ -10,15 +10,15 @@ namespace DHT.Server.Endpoints;
sealed class GetAttachmentEndpoint : BaseEndpoint { sealed class GetAttachmentEndpoint : BaseEndpoint {
public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {} public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {}
protected override Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!); string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!);
DownloadedAttachment? maybeDownloadedAttachment = Db.GetDownloadedAttachment(attachmentUrl); DownloadedAttachment? maybeDownloadedAttachment = await Db.Downloads.GetDownloadedAttachment(attachmentUrl);
if (maybeDownloadedAttachment is {} downloadedAttachment) { if (maybeDownloadedAttachment is {} downloadedAttachment) {
return Task.FromResult<IHttpOutput>(new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data)); return new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data);
} }
else { else {
return Task.FromResult<IHttpOutput>(new HttpOutput.Redirect(attachmentUrl, permanent: false)); return new HttpOutput.Redirect(attachmentUrl, permanent: false);
} }
} }
} }

View File

@ -16,19 +16,19 @@ sealed class TrackChannelEndpoint : BaseEndpoint {
var server = ReadServer(root.RequireObject("server"), "server"); var server = ReadServer(root.RequireObject("server"), "server");
var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id); var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
Db.AddServer(server); await Db.Servers.Add([server]);
Db.AddChannel(channel); await Db.Channels.Add([channel]);
return HttpOutput.None; return HttpOutput.None;
} }
private static Data.Server ReadServer(JsonElement json, string path) => new() { private static Data.Server ReadServer(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Name = json.RequireString("name", path), Name = json.RequireString("name", path),
Type = ServerTypes.FromString(json.RequireString("type", path)) ?? throw new HttpException(HttpStatusCode.BadRequest, "Server type must be either 'SERVER', 'GROUP', or 'DM'.") Type = ServerTypes.FromString(json.RequireString("type", path)) ?? throw new HttpException(HttpStatusCode.BadRequest, "Server type must be either 'SERVER', 'GROUP', or 'DM'.")
}; };
private static Channel ReadChannel(JsonElement json, string path, ulong serverId) => new() { private static Channel ReadChannel(JsonElement json, string path, ulong serverId) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Server = serverId, Server = serverId,
Name = json.RequireString("name", path), Name = json.RequireString("name", path),

View File

@ -18,7 +18,7 @@ namespace DHT.Server.Endpoints;
sealed class TrackMessagesEndpoint : BaseEndpoint { sealed class TrackMessagesEndpoint : BaseEndpoint {
private const string HasNewMessages = "1"; private const string HasNewMessages = "1";
private const string NoNewMessages = "0"; private const string NoNewMessages = "0";
public TrackMessagesEndpoint(IDatabaseFile db) : base(db) {} public TrackMessagesEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
@ -39,14 +39,14 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
} }
var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds }; var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds };
bool anyNewMessages = Db.CountMessages(addedMessageFilter) < messages.Length; bool anyNewMessages = await Db.Messages.Count(addedMessageFilter) < addedMessageIds.Count;
Db.AddMessages(messages); await Db.Messages.Add(messages);
return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages); return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages);
} }
private static Message ReadMessage(JsonElement json, string path) => new() { private static Message ReadMessage(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Sender = json.RequireSnowflake("sender", path), Sender = json.RequireSnowflake("sender", path),
Channel = json.RequireSnowflake("channel", path), Channel = json.RequireSnowflake("channel", path),
@ -54,9 +54,9 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
Timestamp = json.RequireLong("timestamp", path), Timestamp = json.RequireLong("timestamp", path),
EditTimestamp = json.HasKey("editTimestamp") ? json.RequireLong("editTimestamp", path) : null, EditTimestamp = json.HasKey("editTimestamp") ? json.RequireLong("editTimestamp", path) : null,
RepliedToId = json.HasKey("repliedToId") ? json.RequireSnowflake("repliedToId", path) : null, RepliedToId = json.HasKey("repliedToId") ? json.RequireSnowflake("repliedToId", path) : null,
Attachments = json.HasKey("attachments") ? ReadAttachments(json.RequireArray("attachments", path + ".attachments"), path + ".attachments[]").ToImmutableArray() : ImmutableArray<Attachment>.Empty, Attachments = json.HasKey("attachments") ? ReadAttachments(json.RequireArray("attachments", path + ".attachments"), path + ".attachments[]").ToImmutableList() : ImmutableList<Attachment>.Empty,
Embeds = json.HasKey("embeds") ? ReadEmbeds(json.RequireArray("embeds", path + ".embeds"), path + ".embeds[]").ToImmutableArray() : ImmutableArray<Embed>.Empty, Embeds = json.HasKey("embeds") ? ReadEmbeds(json.RequireArray("embeds", path + ".embeds"), path + ".embeds[]").ToImmutableList() : ImmutableList<Embed>.Empty,
Reactions = json.HasKey("reactions") ? ReadReactions(json.RequireArray("reactions", path + ".reactions"), path + ".reactions[]").ToImmutableArray() : ImmutableArray<Reaction>.Empty, Reactions = json.HasKey("reactions") ? ReadReactions(json.RequireArray("reactions", path + ".reactions"), path + ".reactions[]").ToImmutableList() : ImmutableList<Reaction>.Empty,
}; };
[SuppressMessage("ReSharper", "ConvertToLambdaExpression")] [SuppressMessage("ReSharper", "ConvertToLambdaExpression")]

View File

@ -25,12 +25,12 @@ sealed class TrackUsersEndpoint : BaseEndpoint {
users[i++] = ReadUser(user, "user"); users[i++] = ReadUser(user, "user");
} }
Db.AddUsers(users); await Db.Users.Add(users);
return HttpOutput.None; return HttpOutput.None;
} }
private static User ReadUser(JsonElement json, string path) => new() { private static User ReadUser(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Name = json.RequireString("name", path), Name = json.RequireString("name", path),
AvatarUrl = json.HasKey("avatar") ? json.RequireString("avatar", path) : null, AvatarUrl = json.HasKey("avatar") ? json.RequireString("avatar", path) : null,

View File

@ -12,6 +12,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.0" /> <PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.0" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="System.Reactive" Version="6.0.0" /> <PackageReference Include="System.Reactive" Version="6.0.0" />
</ItemGroup> </ItemGroup>

View File

@ -15,7 +15,7 @@ public sealed class AsyncValueComputer<TValue> {
private SoftHardCancellationToken? currentCancellationTokenSource; private SoftHardCancellationToken? currentCancellationTokenSource;
private bool wasHardCancelled = false; private bool wasHardCancelled = false;
private Func<TValue>? currentComputeFunction; private Func<Task<TValue>>? currentComputeFunction;
private bool hasComputeFunctionChanged = false; private bool hasComputeFunctionChanged = false;
private AsyncValueComputer(Action<TValue> resultProcessor, TaskScheduler resultTaskScheduler, bool processOutdatedResults) { private AsyncValueComputer(Action<TValue> resultProcessor, TaskScheduler resultTaskScheduler, bool processOutdatedResults) {
@ -31,7 +31,7 @@ public sealed class AsyncValueComputer<TValue> {
} }
} }
public void Compute(Func<TValue> func) { public void Compute(Func<Task<TValue>> func) {
lock (stateLock) { lock (stateLock) {
wasHardCancelled = false; wasHardCancelled = false;
@ -47,7 +47,7 @@ public sealed class AsyncValueComputer<TValue> {
} }
[SuppressMessage("ReSharper", "MethodSupportsCancellation")] [SuppressMessage("ReSharper", "MethodSupportsCancellation")]
private void EnqueueComputation(Func<TValue> func) { private void EnqueueComputation(Func<Task<TValue>> func) {
var cancellationTokenSource = new SoftHardCancellationToken(); var cancellationTokenSource = new SoftHardCancellationToken();
currentCancellationTokenSource?.RequestSoftCancellation(); currentCancellationTokenSource?.RequestSoftCancellation();
@ -84,9 +84,9 @@ public sealed class AsyncValueComputer<TValue> {
public sealed class Single { public sealed class Single {
private readonly AsyncValueComputer<TValue> baseComputer; private readonly AsyncValueComputer<TValue> baseComputer;
private readonly Func<TValue> resultComputer; private readonly Func<Task<TValue>> resultComputer;
internal Single(AsyncValueComputer<TValue> baseComputer, Func<TValue> resultComputer) { internal Single(AsyncValueComputer<TValue> baseComputer, Func<Task<TValue>> resultComputer) {
this.baseComputer = baseComputer; this.baseComputer = baseComputer;
this.resultComputer = resultComputer; this.resultComputer = resultComputer;
} }
@ -94,6 +94,10 @@ public sealed class AsyncValueComputer<TValue> {
public void Recompute() { public void Recompute() {
baseComputer.Compute(resultComputer); baseComputer.Compute(resultComputer);
} }
public void Cancel() {
baseComputer.Cancel();
}
} }
public static Builder WithResultProcessor(Action<TValue> resultProcessor, TaskScheduler? scheduler = null) { public static Builder WithResultProcessor(Action<TValue> resultProcessor, TaskScheduler? scheduler = null) {
@ -119,7 +123,7 @@ public sealed class AsyncValueComputer<TValue> {
return new AsyncValueComputer<TValue>(resultProcessor, resultTaskScheduler, processOutdatedResults); return new AsyncValueComputer<TValue>(resultProcessor, resultTaskScheduler, processOutdatedResults);
} }
public Single BuildWithComputer(Func<TValue> resultComputer) { public Single BuildWithComputer(Func<Task<TValue>> resultComputer) {
return new Single(Build(), resultComputer); return new Single(Build(), resultComputer);
} }
} }

View File

@ -0,0 +1,45 @@
using System;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Utils.Tasks;
public sealed class RestartableTask<T>(Action<T> resultProcessor, TaskScheduler resultScheduler) {
private readonly object monitor = new ();
private CancellationTokenSource? cancellationTokenSource;
public void Restart(Func<CancellationToken, Task<T>> resultComputer) {
lock (monitor) {
Cancel();
cancellationTokenSource = new CancellationTokenSource();
var taskCancellationTokenSource = cancellationTokenSource;
var taskCancellationToken = taskCancellationTokenSource.Token;
Task.Run(() => resultComputer(taskCancellationToken), taskCancellationToken)
.ContinueWith(task => resultProcessor(task.Result), taskCancellationToken, TaskContinuationOptions.OnlyOnRanToCompletion, resultScheduler)
.ContinueWith(_ => OnTaskFinished(taskCancellationTokenSource), CancellationToken.None);
}
}
public void Cancel() {
lock (monitor) {
if (cancellationTokenSource != null) {
cancellationTokenSource.Cancel();
cancellationTokenSource = null;
}
}
}
private void OnTaskFinished(CancellationTokenSource taskCancellationTokenSource) {
lock (monitor) {
taskCancellationTokenSource.Dispose();
if (cancellationTokenSource == taskCancellationTokenSource) {
cancellationTokenSource = null;
}
}
}
}

View File

@ -0,0 +1,84 @@
using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace DHT.Utils.Tasks;
public abstract class ThrottledTaskBase<T> : IDisposable {
private readonly Channel<Func<CancellationToken, T>> taskChannel = Channel.CreateBounded<Func<CancellationToken, T>>(new BoundedChannelOptions(capacity: 1) {
SingleReader = true,
SingleWriter = false,
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.DropOldest
});
private readonly CancellationTokenSource cancellationTokenSource = new ();
internal ThrottledTaskBase() {}
protected async Task ReaderTask() {
var cancellationToken = cancellationTokenSource.Token;
try {
await foreach (var item in taskChannel.Reader.ReadAllAsync(cancellationToken)) {
try {
await Run(item, cancellationToken);
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Ignore.
}
}
} catch (OperationCanceledException) {
// Ignore.
} finally {
cancellationTokenSource.Dispose();
}
}
protected abstract Task Run(Func<CancellationToken, T> func, CancellationToken cancellationToken);
public void Post(Func<CancellationToken, T> resultComputer) {
taskChannel.Writer.TryWrite(resultComputer);
}
public void Dispose() {
taskChannel.Writer.Complete();
cancellationTokenSource.Cancel();
}
}
public sealed class ThrottledTask : ThrottledTaskBase<Task> {
private readonly Action resultProcessor;
private readonly TaskScheduler resultScheduler;
public ThrottledTask(Action resultProcessor, TaskScheduler resultScheduler) {
this.resultProcessor = resultProcessor;
this.resultScheduler = resultScheduler;
Task.Run(ReaderTask);
}
protected override async Task Run(Func<CancellationToken, Task> func, CancellationToken cancellationToken) {
await func(cancellationToken);
await Task.Factory.StartNew(resultProcessor, cancellationToken, TaskCreationOptions.None, resultScheduler);
}
}
public sealed class ThrottledTask<T> : ThrottledTaskBase<Task<T>> {
private readonly Action<T> resultProcessor;
private readonly TaskScheduler resultScheduler;
public ThrottledTask(Action<T> resultProcessor, TaskScheduler resultScheduler) {
this.resultProcessor = resultProcessor;
this.resultScheduler = resultScheduler;
Task.Run(ReaderTask);
}
protected override async Task Run(Func<CancellationToken, Task<T>> func, CancellationToken cancellationToken) {
T result = await func(cancellationToken);
await Task.Factory.StartNew(() => resultProcessor(result), cancellationToken, TaskCreationOptions.None, resultScheduler);
}
}