1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2024-11-25 14:42:44 +01:00

Compare commits

..

2 Commits

53 changed files with 1753 additions and 1197 deletions

View File

@ -11,6 +11,7 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Exceptions; using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite; using DHT.Server.Database.Sqlite;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging; using DHT.Utils.Logging;
namespace DHT.Desktop.Common; namespace DHT.Desktop.Common;
@ -20,9 +21,9 @@ static class DatabaseGui {
private const string DatabaseFileInitialName = "archive.dht"; private const string DatabaseFileInitialName = "archive.dht";
private static readonly IReadOnlyList<FilePickerFileType> DatabaseFileDialogFilter = new List<FilePickerFileType> { private static readonly IReadOnlyList<FilePickerFileType> DatabaseFileDialogFilter = [
FileDialogs.CreateFilter("Discord History Tracker Database", new [] { "dht" }) FileDialogs.CreateFilter("Discord History Tracker Database", ["dht"])
}; ];
public static async Task<string[]> NewOpenDatabaseFilesDialog(Window window, string? suggestedDirectory) { public static async Task<string[]> NewOpenDatabaseFilesDialog(Window window, string? suggestedDirectory) {
return await window.StorageProvider.OpenFiles(new FilePickerOpenOptions { return await window.StorageProvider.OpenFiles(new FilePickerOpenOptions {

View File

@ -4,5 +4,6 @@ namespace DHT.Desktop.Dialogs.Progress;
interface IProgressCallback { interface IProgressCallback {
Task Update(string message, int finishedItems, int totalItems); Task Update(string message, int finishedItems, int totalItems);
Task UpdateIndeterminate(string message);
Task Hide(); Task Hide();
} }

View File

@ -40,7 +40,7 @@
<TextBlock DockPanel.Dock="Right" Text="{Binding Items}" Classes="items" /> <TextBlock DockPanel.Dock="Right" Text="{Binding Items}" Classes="items" />
<TextBlock DockPanel.Dock="Left" Text="{Binding Message}" /> <TextBlock DockPanel.Dock="Left" Text="{Binding Message}" />
</DockPanel> </DockPanel>
<ProgressBar Value="{Binding Progress}" /> <ProgressBar IsIndeterminate="{Binding IsIndeterminate}" Value="{Binding Progress}" />
</StackPanel> </StackPanel>
</DataTemplate> </DataTemplate>
</ItemsRepeater.ItemTemplate> </ItemsRepeater.ItemTemplate>

View File

@ -7,6 +7,58 @@ namespace DHT.Desktop.Dialogs.Progress;
[SuppressMessage("ReSharper", "MemberCanBeInternal")] [SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class ProgressDialog : Window { public sealed partial class ProgressDialog : Window {
internal static async Task Show(Window owner, string title, Func<ProgressDialog, IProgressCallback, Task> action) {
var taskCompletionSource = new TaskCompletionSource();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
try {
await action(dialog, callbacks[0]);
taskCompletionSource.SetResult();
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
await taskCompletionSource.Task;
}
internal static async Task ShowIndeterminate(Window owner, string title, string message, Func<ProgressDialog, Task> action) {
var taskCompletionSource = new TaskCompletionSource();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
await callbacks[0].UpdateIndeterminate(message);
try {
await action(dialog);
taskCompletionSource.SetResult();
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
await taskCompletionSource.Task;
}
internal static async Task<T> ShowIndeterminate<T>(Window owner, string title, string message, Func<ProgressDialog, Task<T>> action) {
var taskCompletionSource = new TaskCompletionSource<T>();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
await callbacks[0].UpdateIndeterminate(message);
try {
taskCompletionSource.SetResult(await action(dialog));
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
return await taskCompletionSource.Task;
}
private bool isFinished = false; private bool isFinished = false;
private Task progressTask = Task.CompletedTask; private Task progressTask = Task.CompletedTask;

View File

@ -18,9 +18,10 @@ sealed class ProgressDialogModel : BaseModel {
[Obsolete("Designer")] [Obsolete("Designer")]
public ProgressDialogModel() {} public ProgressDialogModel() {}
public ProgressDialogModel(TaskRunner task, int progressItems = 1) { public ProgressDialogModel(string title, TaskRunner task, int progressItems = 1) {
this.Items = Enumerable.Range(0, progressItems).Select(static _ => new ProgressItem()).ToArray(); this.Title = title;
this.task = task; this.task = task;
this.Items = Enumerable.Range(0, progressItems).Select(static _ => new ProgressItem()).ToArray();
} }
internal async Task StartTask() { internal async Task StartTask() {
@ -43,6 +44,16 @@ sealed class ProgressDialogModel : BaseModel {
item.Message = message; item.Message = message;
item.Items = totalItems == 0 ? string.Empty : finishedItems.Format() + " / " + totalItems.Format(); item.Items = totalItems == 0 ? string.Empty : finishedItems.Format() + " / " + totalItems.Format();
item.Progress = totalItems == 0 ? 0 : 100 * finishedItems / totalItems; item.Progress = totalItems == 0 ? 0 : 100 * finishedItems / totalItems;
item.IsIndeterminate = false;
});
}
public async Task UpdateIndeterminate(string message) {
await Dispatcher.UIThread.InvokeAsync(() => {
item.Message = message;
item.Items = string.Empty;
item.Progress = 0;
item.IsIndeterminate = true;
}); });
} }

View File

@ -38,4 +38,11 @@ sealed class ProgressItem : BaseModel {
get => progress; get => progress;
set => Change(ref progress, value); set => Change(ref progress, value);
} }
private bool isIndeterminate;
public bool IsIndeterminate {
get => isIndeterminate;
set => Change(ref isIndeterminate, value);
}
} }

View File

@ -39,7 +39,8 @@ static class DiscordAppSettings {
public static async Task<bool?> AreDevToolsEnabled() { public static async Task<bool?> AreDevToolsEnabled() {
try { try {
return AreDevToolsEnabled(await ReadSettingsJson()); var settingsJson = await ReadSettingsJson().ConfigureAwait(false);
return AreDevToolsEnabled(settingsJson);
} catch (Exception e) { } catch (Exception e) {
Log.Error("Cannot read settings file."); Log.Error("Cannot read settings file.");
Log.Error(e); Log.Error(e);

View File

@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Threading.Tasks;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Server; using DHT.Server;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
@ -13,17 +14,17 @@ namespace DHT.Desktop.Main.Controls;
sealed class AttachmentFilterPanelModel : BaseModel, IDisposable { sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
public sealed record Unit(string Name, uint Scale); public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = { private static readonly Unit[] AllUnits = [
new ("B", 1), new Unit("B", 1),
new ("kB", 1024), new Unit("kB", 1024),
new ("MB", 1024 * 1024) new Unit("MB", 1024 * 1024)
}; ];
private static readonly HashSet<string> FilterProperties = new () { private static readonly HashSet<string> FilterProperties = [
nameof(LimitSize), nameof(LimitSize),
nameof(MaximumSize), nameof(MaximumSize),
nameof(MaximumSizeUnit) nameof(MaximumSizeUnit)
}; ];
public string FilterStatisticsText { get; private set; } = ""; public string FilterStatisticsText { get; private set; } = "";
@ -51,7 +52,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
private readonly State state; private readonly State state;
private readonly string verb; private readonly string verb;
private readonly AsyncValueComputer<long> matchingAttachmentCountComputer; private readonly RestartableTask<long> matchingAttachmentCountTask;
private long? matchingAttachmentCount; private long? matchingAttachmentCount;
private long? totalAttachmentCount; private long? totalAttachmentCount;
@ -62,7 +63,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
this.state = state; this.state = state;
this.verb = verb; this.verb = verb;
this.matchingAttachmentCountComputer = AsyncValueComputer<long>.WithResultProcessor(SetAttachmentCounts).Build(); this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext());
UpdateFilterStatistics(); UpdateFilterStatistics();
@ -90,14 +91,14 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
private void UpdateFilterStatistics() { private void UpdateFilterStatistics() {
var filter = CreateFilter(); var filter = CreateFilter();
if (filter.IsEmpty) { if (filter.IsEmpty) {
matchingAttachmentCountComputer.Cancel(); matchingAttachmentCountTask.Cancel();
matchingAttachmentCount = totalAttachmentCount; matchingAttachmentCount = totalAttachmentCount;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
} }
else { else {
matchingAttachmentCount = null; matchingAttachmentCount = null;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
matchingAttachmentCountComputer.Compute(() => state.Db.CountAttachments(filter)); matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Downloads.CountAttachments(filter, cancellationToken));
} }
} }
@ -115,7 +116,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
} }
public AttachmentFilter CreateFilter() { public AttachmentFilter CreateFilter() {
AttachmentFilter filter = new(); AttachmentFilter filter = new ();
if (LimitSize) { if (LimitSize) {
try { try {

View File

@ -8,6 +8,7 @@ using Avalonia.Controls;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.CheckBox; using DHT.Desktop.Dialogs.CheckBox;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server; using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
@ -18,7 +19,7 @@ using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls; namespace DHT.Desktop.Main.Controls;
sealed class MessageFilterPanelModel : BaseModel, IDisposable { sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private static readonly HashSet<string> FilterProperties = new () { private static readonly HashSet<string> FilterProperties = [
nameof(FilterByDate), nameof(FilterByDate),
nameof(StartDate), nameof(StartDate),
nameof(EndDate), nameof(EndDate),
@ -26,7 +27,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
nameof(IncludedChannels), nameof(IncludedChannels),
nameof(FilterByUser), nameof(FilterByUser),
nameof(IncludedUsers) nameof(IncludedUsers)
}; ];
public string FilterStatisticsText { get; private set; } = ""; public string FilterStatisticsText { get; private set; } = "";
@ -62,8 +63,8 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
set => Change(ref filterByChannel, value); set => Change(ref filterByChannel, value);
} }
public HashSet<ulong> IncludedChannels { public HashSet<ulong>? IncludedChannels {
get => includedChannels ?? state.Db.GetAllChannels().Select(static channel => channel.Id).ToHashSet(); get => includedChannels;
set => Change(ref includedChannels, value); set => Change(ref includedChannels, value);
} }
@ -72,8 +73,8 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
set => Change(ref filterByUser, value); set => Change(ref filterByUser, value);
} }
public HashSet<ulong> IncludedUsers { public HashSet<ulong>? IncludedUsers {
get => includedUsers ?? state.Db.GetAllUsers().Select(static user => user.Id).ToHashSet(); get => includedUsers;
set => Change(ref includedUsers, value); set => Change(ref includedUsers, value);
} }
@ -95,7 +96,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private readonly State state; private readonly State state;
private readonly string verb; private readonly string verb;
private readonly AsyncValueComputer<long> exportedMessageCountComputer; private readonly RestartableTask<long> exportedMessageCountTask;
private long? exportedMessageCount; private long? exportedMessageCount;
private long? totalMessageCount; private long? totalMessageCount;
@ -107,7 +108,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
this.state = state; this.state = state;
this.verb = verb; this.verb = verb;
this.exportedMessageCountComputer = AsyncValueComputer<long>.WithResultProcessor(SetExportedMessageCount).Build(); this.exportedMessageCountTask = new RestartableTask<long>(SetExportedMessageCount, TaskScheduler.FromCurrentSynchronizationContext());
UpdateFilterStatistics(); UpdateFilterStatistics();
UpdateChannelFilterLabel(); UpdateChannelFilterLabel();
@ -118,6 +119,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
} }
public void Dispose() { public void Dispose() {
exportedMessageCountTask.Cancel();
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged; state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
} }
@ -148,17 +150,29 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
} }
} }
private void UpdateChannelFilterLabel() {
long total = state.Db.Statistics.TotalChannels;
long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
private void UpdateUserFilterLabel() {
long total = state.Db.Statistics.TotalUsers;
long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
private void UpdateFilterStatistics() { private void UpdateFilterStatistics() {
var filter = CreateFilter(); var filter = CreateFilter();
if (filter.IsEmpty) { if (filter.IsEmpty) {
exportedMessageCountComputer.Cancel(); exportedMessageCountTask.Cancel();
exportedMessageCount = totalMessageCount; exportedMessageCount = totalMessageCount;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
} }
else { else {
exportedMessageCount = null; exportedMessageCount = null;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
exportedMessageCountComputer.Compute(() => state.Db.CountMessages(filter)); exportedMessageCountTask.Restart(cancellationToken => state.Db.Messages.Count(filter, cancellationToken));
} }
} }
@ -175,12 +189,12 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
OnPropertyChanged(nameof(FilterStatisticsText)); OnPropertyChanged(nameof(FilterStatisticsText));
} }
public async void OpenChannelFilterDialog() { public async Task OpenChannelFilterDialog() {
var servers = state.Db.GetAllServers().ToDictionary(static server => server.Id); async Task<List<CheckBoxItem<ulong>>> PrepareChannelItems(ProgressDialog dialog) {
var items = new List<CheckBoxItem<ulong>>(); var items = new List<CheckBoxItem<ulong>>();
var included = IncludedChannels; var servers = await state.Db.Servers.Get().ToDictionaryAsync(static server => server.Id);
foreach (var channel in state.Db.GetAllChannels()) { await foreach (var channel in state.Db.Channels.Get()) {
var channelId = channel.Id; var channelId = channel.Id;
var channelName = channel.Name; var channelName = channel.Name;
@ -210,68 +224,63 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
items.Add(new CheckBoxItem<ulong>(channelId) { items.Add(new CheckBoxItem<ulong>(channelId) {
Title = title, Title = title,
Checked = included.Contains(channelId) Checked = IncludedChannels == null || IncludedChannels.Contains(channelId)
}); });
} }
var result = await OpenIdFilterDialog(window, "Included Channels", items); return items;
}
const string Title = "Included Channels";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading channels...", PrepareChannelItems);
} catch (Exception e) {
await Dialog.ShowOk(window, Title, "Error loading channels: " + e.Message);
return;
}
var result = await OpenIdFilterDialog(Title, items);
if (result != null) { if (result != null) {
IncludedChannels = result; IncludedChannels = result;
} }
} }
public async void OpenUserFilterDialog() { public async Task OpenUserFilterDialog() {
var items = new List<CheckBoxItem<ulong>>(); async Task<List<CheckBoxItem<ulong>>> PrepareUserItems(ProgressDialog dialog) {
var included = IncludedUsers; var checkBoxItems = new List<CheckBoxItem<ulong>>();
foreach (var user in state.Db.GetAllUsers()) { await foreach (var user in state.Db.Users.Get()) {
var name = user.Name; var name = user.Name;
var discriminator = user.Discriminator; var discriminator = user.Discriminator;
items.Add(new CheckBoxItem<ulong>(user.Id) { checkBoxItems.Add(new CheckBoxItem<ulong>(user.Id) {
Title = discriminator == null ? name : name + " #" + discriminator, Title = discriminator == null ? name : name + " #" + discriminator,
Checked = included.Contains(user.Id) Checked = IncludedUsers == null || IncludedUsers.Contains(user.Id)
}); });
} }
var result = await OpenIdFilterDialog(window, "Included Users", items); return checkBoxItems;
}
const string Title = "Included Users";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading users...", PrepareUserItems);
} catch (Exception e) {
await Dialog.ShowOk(window, Title, "Error loading users: " + e.Message);
return;
}
var result = await OpenIdFilterDialog(Title, items);
if (result != null) { if (result != null) {
IncludedUsers = result; IncludedUsers = result;
} }
} }
private void UpdateChannelFilterLabel() { private async Task<HashSet<ulong>?> OpenIdFilterDialog(string title, List<CheckBoxItem<ulong>> items) {
long total = state.Db.Statistics.TotalChannels;
long included = FilterByChannel ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
private void UpdateUserFilterLabel() {
long total = state.Db.Statistics.TotalUsers;
long included = FilterByUser ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
public MessageFilter CreateFilter() {
MessageFilter filter = new();
if (FilterByDate) {
filter.StartDate = StartDate;
filter.EndDate = EndDate?.AddDays(1).AddMilliseconds(-1);
}
if (FilterByChannel) {
filter.ChannelIds = new HashSet<ulong>(IncludedChannels);
}
if (FilterByUser) {
filter.UserIds = new HashSet<ulong>(IncludedUsers);
}
return filter;
}
private static async Task<HashSet<ulong>?> OpenIdFilterDialog(Window window, string title, List<CheckBoxItem<ulong>> items) {
items.Sort(static (item1, item2) => item1.Title.CompareTo(item2.Title)); items.Sort(static (item1, item2) => item1.Title.CompareTo(item2.Title));
var model = new CheckBoxDialogModel<ulong>(items) { var model = new CheckBoxDialogModel<ulong>(items) {
@ -283,4 +292,23 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
return result == DialogResult.OkCancel.Ok ? model.SelectedItems.Select(static item => item.Item).ToHashSet() : null; return result == DialogResult.OkCancel.Ok ? model.SelectedItems.Select(static item => item.Item).ToHashSet() : null;
} }
public MessageFilter CreateFilter() {
MessageFilter filter = new ();
if (FilterByDate) {
filter.StartDate = StartDate;
filter.EndDate = EndDate?.AddDays(1).AddMilliseconds(-1);
}
if (FilterByChannel && IncludedChannels != null) {
filter.ChannelIds = new HashSet<ulong>(IncludedChannels);
}
if (FilterByUser && IncludedUsers != null) {
filter.UserIds = new HashSet<ulong>(IncludedUsers);
}
return filter;
}
} }

View File

@ -121,9 +121,10 @@ sealed class ServerConfigurationPanelModel : BaseModel, IDisposable {
ServerConfiguration.Port = port; ServerConfiguration.Port = port;
ServerConfiguration.Token = inputToken; ServerConfiguration.Token = inputToken;
await StartServer();
OnPropertyChanged(nameof(HasMadeChanges)); OnPropertyChanged(nameof(HasMadeChanges));
await StartServer();
} }
public void OnClickCancelChanges() { public void OnClickCancelChanges() {

View File

@ -1,5 +1,4 @@
using System; using System;
using System.ComponentModel;
using System.IO; using System.IO;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -8,6 +7,7 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Main.Screens; using DHT.Desktop.Main.Screens;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Server; using DHT.Server;
using DHT.Server.Database;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models; using DHT.Utils.Models;
@ -25,7 +25,6 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
private readonly WelcomeScreen welcomeScreen; private readonly WelcomeScreen welcomeScreen;
private readonly WelcomeScreenModel welcomeScreenModel; private readonly WelcomeScreenModel welcomeScreenModel;
private MainContentScreen? mainContentScreen;
private MainContentScreenModel? mainContentScreenModel; private MainContentScreenModel? mainContentScreenModel;
private readonly Window window; private readonly Window window;
@ -39,11 +38,11 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
this.window = window; this.window = window;
welcomeScreenModel = new WelcomeScreenModel(window); welcomeScreenModel = new WelcomeScreenModel(window);
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel }; welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel };
CurrentScreen = welcomeScreen; CurrentScreen = welcomeScreen;
welcomeScreenModel.PropertyChanged += WelcomeScreenModelOnPropertyChanged;
var dbFile = args.DatabaseFile; var dbFile = args.DatabaseFile;
if (!string.IsNullOrWhiteSpace(dbFile)) { if (!string.IsNullOrWhiteSpace(dbFile)) {
async void OnWindowOpened(object? o, EventArgs eventArgs) { async void OnWindowOpened(object? o, EventArgs eventArgs) {
@ -74,26 +73,12 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
} }
} }
private async void WelcomeScreenModelOnPropertyChanged(object? sender, PropertyChangedEventArgs e) { private async void OnDatabaseSelected(object? sender, IDatabaseFile db) {
if (e.PropertyName == nameof(welcomeScreenModel.Db)) { welcomeScreenModel.DatabaseSelected -= OnDatabaseSelected;
if (mainContentScreenModel != null) {
mainContentScreenModel.DatabaseClosed -= MainContentScreenModelOnDatabaseClosed;
mainContentScreenModel.Dispose();
}
if (state != null) { await DisposeState();
await state.DisposeAsync();
}
if (welcomeScreenModel.Db == null) { state = new State(db);
state = null;
Title = DefaultTitle;
mainContentScreenModel = null;
mainContentScreen = null;
CurrentScreen = welcomeScreen;
}
else {
state = new State(welcomeScreenModel.Db);
try { try {
await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token); await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token);
@ -102,33 +87,45 @@ sealed class MainWindowModel : BaseModel, IAsyncDisposable {
await Dialog.ShowOk(window, "Internal Server Error", ex.Message); await Dialog.ShowOk(window, "Internal Server Error", ex.Message);
} }
Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
mainContentScreenModel = new MainContentScreenModel(window, state); mainContentScreenModel = new MainContentScreenModel(window, state);
await mainContentScreenModel.Initialize();
mainContentScreenModel.DatabaseClosed += MainContentScreenModelOnDatabaseClosed; mainContentScreenModel.DatabaseClosed += MainContentScreenModelOnDatabaseClosed;
mainContentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
CurrentScreen = mainContentScreen;
}
OnPropertyChanged(nameof(CurrentScreen)); Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
CurrentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
OnPropertyChanged(nameof(Title)); OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
window.Focus(); window.Focus();
} }
private async void MainContentScreenModelOnDatabaseClosed(object? sender, EventArgs e) {
if (mainContentScreenModel != null) {
mainContentScreenModel.DatabaseClosed -= MainContentScreenModelOnDatabaseClosed;
mainContentScreenModel.Dispose();
mainContentScreenModel = null;
} }
private void MainContentScreenModelOnDatabaseClosed(object? sender, EventArgs e) { await DisposeState();
welcomeScreenModel.CloseDatabase();
Title = DefaultTitle;
CurrentScreen = welcomeScreen;
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
} }
public async ValueTask DisposeAsync() { private async Task DisposeState() {
mainContentScreenModel?.Dispose();
if (state != null) { if (state != null) {
await state.DisposeAsync(); await state.DisposeAsync();
state = null; state = null;
} }
}
welcomeScreenModel.Dispose(); public async ValueTask DisposeAsync() {
mainContentScreenModel?.Dispose();
await DisposeState();
} }
} }

View File

@ -1,6 +1,8 @@
using System; using System;
using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Server; using DHT.Server;
using DHT.Utils.Models; using DHT.Utils.Models;
@ -27,8 +29,9 @@ sealed class AdvancedPageModel : BaseModel, IDisposable {
ServerConfigurationModel.Dispose(); ServerConfigurationModel.Dispose();
} }
public async void VacuumDatabase() { public async Task VacuumDatabase() {
state.Db.Vacuum(); const string Title = "Vacuum Database";
await Dialog.ShowOk(window, "Vacuum Database", "Done."); await ProgressDialog.ShowIndeterminate(window, Title, "Vacuuming database...", _ => state.Db.Vacuum());
await Dialog.ShowOk(window, Title, "Done.");
} }
} }

View File

@ -48,7 +48,7 @@
</DataGrid> </DataGrid>
</Expander> </Expander>
<StackPanel Orientation="Horizontal" Spacing="10"> <StackPanel Orientation="Horizontal" Spacing="10">
<Button Command="{Binding OnClickRetryFailedDownloads}" IsEnabled="{Binding HasFailedDownloads}">Retry Failed Downloads</Button> <Button Command="{Binding OnClickRetryFailedDownloads}" IsEnabled="{Binding IsRetryFailedOnDownloadsButtonEnabled}">Retry Failed Downloads</Button>
</StackPanel> </StackPanel>
</StackPanel> </StackPanel>
</StackPanel> </StackPanel>

View File

@ -11,12 +11,15 @@ using DHT.Server.Data;
using DHT.Server.Data.Aggregations; using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Models; using DHT.Utils.Models;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class AttachmentsPageModel : BaseModel, IDisposable { sealed class AttachmentsPageModel : BaseModel, IDisposable {
private static readonly Log Log = Log.ForType<AttachmentsPageModel>();
private static readonly DownloadItemFilter EnqueuedItemFilter = new () { private static readonly DownloadItemFilter EnqueuedItemFilter = new () {
IncludeStatuses = new HashSet<DownloadStatus> { IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued, DownloadStatus.Enqueued,
@ -24,14 +27,26 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
} }
}; };
private bool isThreadDownloadButtonEnabled = true; private bool isToggleDownloadButtonEnabled = true;
public bool IsToggleDownloadButtonEnabled {
get => isToggleDownloadButtonEnabled;
set => Change(ref isToggleDownloadButtonEnabled, value);
}
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading"; public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
public bool IsToggleDownloadButtonEnabled { private bool isRetryingFailedDownloads = false;
get => isThreadDownloadButtonEnabled;
set => Change(ref isThreadDownloadButtonEnabled, value); public bool IsRetryingFailedDownloads {
get => isRetryingFailedDownloads;
set {
isRetryingFailedDownloads = value;
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
} }
}
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && HasFailedDownloads;
public string DownloadMessage { get; set; } = ""; public string DownloadMessage { get; set; } = "";
public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value; public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value;
@ -43,22 +58,19 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
private readonly StatisticsRow statisticsFailed = new ("Failed"); private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped"); private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public List<StatisticsRow> StatisticsRows { public List<StatisticsRow> StatisticsRows => [
get {
return new List<StatisticsRow> {
statisticsEnqueued, statisticsEnqueued,
statisticsDownloaded, statisticsDownloaded,
statisticsFailed, statisticsFailed,
statisticsSkipped statisticsSkipped
}; ];
}
}
public bool IsDownloading => state.Downloader.IsDownloading; public bool IsDownloading => state.Downloader.IsDownloading;
public bool HasFailedDownloads => statisticsFailed.Items > 0; public bool HasFailedDownloads => statisticsFailed.Items > 0;
private readonly State state; private readonly State state;
private readonly AsyncValueComputer<DownloadStatusStatistics>.Single downloadStatisticsComputer; private readonly ThrottledTask enqueueDownloadItemsTask;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private IDisposable? finishedItemsSubscription; private IDisposable? finishedItemsSubscription;
private int doneItemsCount; private int doneItemsCount;
@ -72,14 +84,16 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
FilterModel = new AttachmentFilterPanelModel(state); FilterModel = new AttachmentFilterPanelModel(state);
downloadStatisticsComputer = AsyncValueComputer<DownloadStatusStatistics>.WithResultProcessor(UpdateStatistics).WithOutdatedResults().BuildWithComputer(state.Db.GetDownloadStatusStatistics); enqueueDownloadItemsTask = new ThrottledTask(RecomputeDownloadStatistics, TaskScheduler.FromCurrentSynchronizationContext());
downloadStatisticsComputer.Recompute(); downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
RecomputeDownloadStatistics();
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged; state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
} }
public void Dispose() { public void Dispose() {
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged; state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
downloadStatisticsTask.Dispose();
finishedItemsSubscription?.Dispose(); finishedItemsSubscription?.Dispose();
FilterModel.Dispose(); FilterModel.Dispose();
} }
@ -87,23 +101,34 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) { private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) { if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
if (IsDownloading) { if (IsDownloading) {
EnqueueDownloadItems(); EnqueueDownloadItemsLater();
} }
else { else {
downloadStatisticsComputer.Recompute(); RecomputeDownloadStatistics();
} }
} }
else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) { else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) {
downloadStatisticsComputer.Recompute(); RecomputeDownloadStatistics();
} }
} }
private void EnqueueDownloadItems() { private async Task EnqueueDownloadItems() {
await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter());
}
private void EnqueueDownloadItemsLater() {
var filter = CreateAttachmentFilter();
enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken));
}
private AttachmentFilter CreateAttachmentFilter() {
var filter = FilterModel.CreateFilter(); var filter = FilterModel.CreateFilter();
filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent; filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent;
state.Db.EnqueueDownloadItems(filter); return filter;
}
downloadStatisticsComputer.Recompute(); private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics);
} }
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) { private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
@ -125,6 +150,7 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
if (hadFailedDownloads != HasFailedDownloads) { if (hadFailedDownloads != HasFailedDownloads) {
OnPropertyChanged(nameof(HasFailedDownloads)); OnPropertyChanged(nameof(HasFailedDownloads));
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
} }
totalItemsToDownloadCount = statisticsEnqueued.Items + statisticsDownloaded.Items + statisticsFailed.Items - initialFinishedCount; totalItemsToDownloadCount = statisticsEnqueued.Items + statisticsDownloaded.Items + statisticsFailed.Items - initialFinishedCount;
@ -138,12 +164,6 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
OnPropertyChanged(nameof(DownloadProgress)); OnPropertyChanged(nameof(DownloadProgress));
} }
private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
downloadStatisticsComputer.Recompute();
}
public async Task OnClickToggleDownload() { public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false; IsToggleDownloadButtonEnabled = false;
@ -153,9 +173,9 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
finishedItemsSubscription?.Dispose(); finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null; finishedItemsSubscription = null;
downloadStatisticsComputer.Recompute(); RecomputeDownloadStatistics();
state.Db.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching); await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0; doneItemsCount = 0;
initialFinishedCount = 0; initialFinishedCount = 0;
@ -173,7 +193,7 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
.ObserveOn(AvaloniaScheduler.Instance) .ObserveOn(AvaloniaScheduler.Instance)
.Subscribe(OnItemsFinished); .Subscribe(OnItemsFinished);
EnqueueDownloadItems(); await EnqueueDownloadItems();
} }
OnPropertyChanged(nameof(ToggleDownloadButtonText)); OnPropertyChanged(nameof(ToggleDownloadButtonText));
@ -181,7 +201,16 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
IsToggleDownloadButtonEnabled = true; IsToggleDownloadButtonEnabled = true;
} }
public void OnClickRetryFailedDownloads() { private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
var allExceptFailedFilter = new DownloadItemFilter { var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> { IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued, DownloadStatus.Enqueued,
@ -190,10 +219,15 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
} }
}; };
state.Db.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching); await state.Db.Downloads.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) { if (IsDownloading) {
EnqueueDownloadItems(); await EnqueueDownloadItems();
}
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
} }
} }

View File

@ -18,7 +18,7 @@ using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Import; using DHT.Server.Database.Import;
using DHT.Server.Database.Sqlite; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models; using DHT.Utils.Models;
@ -41,7 +41,7 @@ sealed class DatabasePageModel : BaseModel {
this.Db = state.Db; this.Db = state.Db;
} }
public async void OpenDatabaseFolder() { public async Task OpenDatabaseFolder() {
string file = Db.Path; string file = Db.Path;
string? folder = Path.GetDirectoryName(file); string? folder = Path.GetDirectoryName(file);
@ -72,18 +72,11 @@ sealed class DatabasePageModel : BaseModel {
DatabaseClosed?.Invoke(this, EventArgs.Empty); DatabaseClosed?.Invoke(this, EventArgs.Empty);
} }
public async void MergeWithDatabase() { public async Task MergeWithDatabase() {
var paths = await DatabaseGui.NewOpenDatabaseFilesDialog(window, Path.GetDirectoryName(Db.Path)); var paths = await DatabaseGui.NewOpenDatabaseFilesDialog(window, Path.GetDirectoryName(Db.Path));
if (paths.Length == 0) { if (paths.Length > 0) {
return; await ProgressDialog.Show(window, "Database Merge", async (dialog, callback) => await MergeWithDatabaseFromPaths(Db, paths, dialog, callback));
} }
ProgressDialog progressDialog = new ProgressDialog();
progressDialog.DataContext = new ProgressDialogModel(async callbacks => await MergeWithDatabaseFromPaths(Db, paths, progressDialog, callbacks[0])) {
Title = "Database Merge"
};
await progressDialog.ShowProgressDialog(window);
} }
private static async Task MergeWithDatabaseFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) { private static async Task MergeWithDatabaseFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
@ -97,7 +90,7 @@ sealed class DatabasePageModel : BaseModel {
} }
try { try {
target.AddFrom(db); await target.AddFrom(db);
return true; return true;
} finally { } finally {
db.Dispose(); db.Dispose();
@ -140,23 +133,16 @@ sealed class DatabasePageModel : BaseModel {
} }
} }
public async void ImportLegacyArchive() { public async Task ImportLegacyArchive() {
var paths = await window.StorageProvider.OpenFiles(new FilePickerOpenOptions { var paths = await window.StorageProvider.OpenFiles(new FilePickerOpenOptions {
Title = "Open Legacy DHT Archive", Title = "Open Legacy DHT Archive",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(Db.Path)), SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(Db.Path)),
AllowMultiple = true AllowMultiple = true
}); });
if (paths.Length == 0) { if (paths.Length > 0) {
return; await ProgressDialog.Show(window, "Legacy Archive Import", async (dialog, callback) => await ImportLegacyArchiveFromPaths(Db, paths, dialog, callback));
} }
ProgressDialog progressDialog = new ProgressDialog();
progressDialog.DataContext = new ProgressDialogModel(async callbacks => await ImportLegacyArchiveFromPaths(Db, paths, progressDialog, callbacks[0])) {
Title = "Legacy Archive Import"
};
await progressDialog.ShowProgressDialog(window);
} }
private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) { private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
@ -208,7 +194,7 @@ sealed class DatabasePageModel : BaseModel {
private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) { private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) {
int total = paths.Length; int total = paths.Length;
var oldStatistics = target.SnapshotStatistics(); var oldStatistics = await target.SnapshotStatistics();
int successful = 0; int successful = 0;
int finished = 0; int finished = 0;
@ -239,7 +225,8 @@ sealed class DatabasePageModel : BaseModel {
return; return;
} }
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, target.SnapshotStatistics(), successful, total, itemName)); var newStatistics = await target.SnapshotStatistics();
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, newStatistics, successful, total, itemName));
} }
private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) { private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) {

View File

@ -44,13 +44,7 @@ namespace DHT.Desktop.Main.Pages {
return; return;
} }
ProgressDialog progressDialog = new ProgressDialog { await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
DataContext = new ProgressDialogModel(async callbacks => await GenerateRandomData(channels, users, messages, callbacks[0])) {
Title = "Generating Random Data"
}
};
await progressDialog.ShowProgressDialog(window);
} }
private const int BatchSize = 500; private const int BatchSize = 500;
@ -83,12 +77,9 @@ namespace DHT.Desktop.Main.Pages {
Discriminator = rand.Next(0, 9999).ToString(), Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray(); }).ToArray();
state.Db.AddServer(server); await state.Db.Users.Add(users);
state.Db.AddUsers(users); await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
foreach (var channel in channels) {
state.Db.AddChannel(channel);
}
var now = DateTimeOffset.Now; var now = DateTimeOffset.Now;
int batchIndex = 0; int batchIndex = 0;
@ -111,13 +102,13 @@ namespace DHT.Desktop.Main.Pages {
Timestamp = timeMillis, Timestamp = timeMillis,
EditTimestamp = editMillis, EditTimestamp = editMillis,
RepliedToId = null, RepliedToId = null,
Attachments = ImmutableArray<Attachment>.Empty, Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableArray<Embed>.Empty, Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableArray<Reaction>.Empty, Reactions = ImmutableList<Reaction>.Empty,
}; };
}).ToArray(); }).ToArray();
state.Db.AddMessages(messages); await state.Db.Messages.Add(messages);
messageCount -= BatchSize; messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount); await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);

View File

@ -16,7 +16,7 @@
To start tracking messages, copy the tracking script and paste it into the console of either the Discord app, or your browser. The console is usually opened by pressing Ctrl+Shift+I. To start tracking messages, copy the tracking script and paste it into the console of either the Discord app, or your browser. The console is usually opened by pressing Ctrl+Shift+I.
</TextBlock> </TextBlock>
<StackPanel DockPanel.Dock="Left" Orientation="Horizontal" Spacing="10"> <StackPanel DockPanel.Dock="Left" Orientation="Horizontal" Spacing="10">
<Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick">Copy Tracking Script</Button> <Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick" IsEnabled="{Binding IsCopyTrackingScriptButtonEnabled}">Copy Tracking Script</Button>
</StackPanel> </StackPanel>
<TextBlock TextWrapping="Wrap" Margin="0 5 0 0"> <TextBlock TextWrapping="Wrap" Margin="0 5 0 0">
By default, the Discord app does not allow opening the console. The button below will change a hidden setting in the Discord app that controls whether the Ctrl+Shift+I shortcut is enabled. By default, the Discord app does not allow opening the console. The button below will change a hidden setting in the Discord app that controls whether the Ctrl+Shift+I shortcut is enabled.

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Web; using System.Web;
using Avalonia.Controls; using Avalonia.Controls;
@ -11,9 +12,16 @@ using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class TrackingPageModel : BaseModel { sealed class TrackingPageModel : BaseModel {
private bool areDevToolsEnabled; private bool isCopyTrackingScriptButtonEnabled = true;
private bool AreDevToolsEnabled { public bool IsCopyTrackingScriptButtonEnabled {
get => isCopyTrackingScriptButtonEnabled;
set => Change(ref isCopyTrackingScriptButtonEnabled, value);
}
private bool? areDevToolsEnabled;
private bool? AreDevToolsEnabled {
get => areDevToolsEnabled; get => areDevToolsEnabled;
set { set {
Change(ref areDevToolsEnabled, value); Change(ref areDevToolsEnabled, value);
@ -21,15 +29,19 @@ sealed class TrackingPageModel : BaseModel {
} }
} }
public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = true; public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = false;
public string ToggleAppDevToolsButtonText { public string ToggleAppDevToolsButtonText {
get { get {
if (!AreDevToolsEnabled.HasValue) {
return "Loading...";
}
if (!IsToggleAppDevToolsButtonEnabled) { if (!IsToggleAppDevToolsButtonEnabled) {
return "Unavailable"; return "Unavailable";
} }
return AreDevToolsEnabled ? "Disable Ctrl+Shift+I" : "Enable Ctrl+Shift+I"; return AreDevToolsEnabled.Value ? "Disable Ctrl+Shift+I" : "Enable Ctrl+Shift+I";
} }
} }
@ -40,20 +52,21 @@ sealed class TrackingPageModel : BaseModel {
public TrackingPageModel(Window window) { public TrackingPageModel(Window window) {
this.window = window; this.window = window;
}
public async Task Initialize() { Task.Factory.StartNew(InitializeDevToolsToggle, CancellationToken.None, TaskCreationOptions.None, TaskScheduler.FromCurrentSynchronizationContext());
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled();
if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value;
}
else {
IsToggleAppDevToolsButtonEnabled = false;
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
}
} }
public async Task<bool> OnClickCopyTrackingScript() { public async Task<bool> OnClickCopyTrackingScript() {
IsCopyTrackingScriptButtonEnabled = false;
try {
return await CopyTrackingScript();
} finally {
IsCopyTrackingScriptButtonEnabled = true;
}
}
private async Task<bool> CopyTrackingScript() {
string url = $"http://127.0.0.1:{ServerConfiguration.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}"; string url = $"http://127.0.0.1:{ServerConfiguration.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}";
string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url); string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url);
@ -72,10 +85,28 @@ sealed class TrackingPageModel : BaseModel {
} }
} }
public async void OnClickToggleAppDevTools() { private async Task InitializeDevToolsToggle() {
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled();
if (devToolsEnabled.HasValue) {
IsToggleAppDevToolsButtonEnabled = true;
AreDevToolsEnabled = devToolsEnabled.Value;
}
else {
IsToggleAppDevToolsButtonEnabled = false;
}
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
}
public async Task OnClickToggleAppDevTools() {
const string DialogTitle = "Discord App Settings File"; const string DialogTitle = "Discord App Settings File";
bool oldState = AreDevToolsEnabled; if (!AreDevToolsEnabled.HasValue) {
return;
}
bool oldState = AreDevToolsEnabled.Value;
bool newState = !oldState; bool newState = !oldState;
switch (await DiscordAppSettings.ConfigureDevTools(newState)) { switch (await DiscordAppSettings.ConfigureDevTools(newState)) {

View File

@ -11,6 +11,7 @@ using Avalonia.Platform.Storage;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.File; using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Server; using DHT.Server;
@ -23,7 +24,11 @@ using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class ViewerPageModel : BaseModel, IDisposable { sealed class ViewerPageModel : BaseModel, IDisposable {
public static readonly ConcurrentBag<string> TemporaryFiles = new (); public static readonly ConcurrentBag<string> TemporaryFiles = [];
private static readonly FilePickerFileType[] ViewerFileTypes = [
FileDialogs.CreateFilter("Discord History Viewer", ["html"])
];
public bool DatabaseToolFilterModeKeep { get; set; } = true; public bool DatabaseToolFilterModeKeep { get; set; } = true;
public bool DatabaseToolFilterModeRemove { get; set; } = false; public bool DatabaseToolFilterModeRemove { get; set; } = false;
@ -59,6 +64,58 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
HasFilters = FilterModel.HasAnyFilters; HasFilters = FilterModel.HasAnyFilters;
} }
public async void OnClickOpenViewer() {
try {
var fullPath = await PrepareTemporaryViewerFile();
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
await WriteViewerFile(fullPath, strategy);
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
} catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not save viewer: " + e.Message);
}
}
private async Task<string> PrepareTemporaryViewerFile() {
return await Task.Run(() => {
string rootPath = Path.Combine(Path.GetTempPath(), "DiscordHistoryTracker");
string filenameBase = Path.GetFileNameWithoutExtension(state.Db.Path) + "-" + DateTime.Now.ToString("yyyy-MM-dd");
string fullPath = Path.Combine(rootPath, filenameBase + ".html");
int counter = 0;
while (File.Exists(fullPath)) {
++counter;
fullPath = Path.Combine(rootPath, filenameBase + "-" + counter + ".html");
}
TemporaryFiles.Add(fullPath);
Directory.CreateDirectory(rootPath);
return fullPath;
});
}
public async void OnClickSaveViewer() {
string? path = await window.StorageProvider.SaveFile(new FilePickerSaveOptions {
Title = "Save Viewer",
FileTypeChoices = ViewerFileTypes,
SuggestedFileName = Path.GetFileNameWithoutExtension(state.Db.Path) + ".html",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(state.Db.Path)),
});
if (path == null) {
return;
}
try {
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
} catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not save viewer: " + e.Message);
}
}
private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) { private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) {
const string ArchiveTag = "/*[ARCHIVE]*/"; const string ArchiveTag = "/*[ARCHIVE]*/";
@ -96,54 +153,23 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
File.Delete(jsonTempFile); File.Delete(jsonTempFile);
} }
public async void OnClickOpenViewer() { public async Task OnClickApplyFiltersToDatabase() {
string rootPath = Path.Combine(Path.GetTempPath(), "DiscordHistoryTracker");
string filenameBase = Path.GetFileNameWithoutExtension(state.Db.Path) + "-" + DateTime.Now.ToString("yyyy-MM-dd");
string fullPath = Path.Combine(rootPath, filenameBase + ".html");
int counter = 0;
while (File.Exists(fullPath)) {
++counter;
fullPath = Path.Combine(rootPath, filenameBase + "-" + counter + ".html");
}
TemporaryFiles.Add(fullPath);
Directory.CreateDirectory(rootPath);
await WriteViewerFile(fullPath, new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token));
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
}
private static readonly FilePickerFileType[] ViewerFileTypes = {
FileDialogs.CreateFilter("Discord History Viewer", new string[] { "html" }),
};
public async void OnClickSaveViewer() {
string? path = await window.StorageProvider.SaveFile(new FilePickerSaveOptions {
Title = "Save Viewer",
FileTypeChoices = ViewerFileTypes,
SuggestedFileName = Path.GetFileNameWithoutExtension(state.Db.Path) + ".html",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(state.Db.Path)),
});
if (path != null) {
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
}
}
public async void OnClickApplyFiltersToDatabase() {
var filter = FilterModel.CreateFilter(); var filter = FilterModel.CreateFilter();
var messageCount = await ProgressDialog.ShowIndeterminate(window, "Apply Filters", "Counting matching messages...", _ => state.Db.Messages.Count(filter));
if (DatabaseToolFilterModeKeep) { if (DatabaseToolFilterModeKeep) {
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Keep Matching Messages in This Database", state.Db.CountMessages(filter).Pluralize("message") + " will be kept, and the rest will be removed from this database. This action cannot be undone. Proceed?")) { if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Keep Matching Messages in This Database", messageCount.Pluralize("message") + " will be kept, and the rest will be removed from this database. This action cannot be undone. Proceed?")) {
state.Db.RemoveMessages(filter, FilterRemovalMode.KeepMatching); await ApplyFilterToDatabase(filter, FilterRemovalMode.KeepMatching);
} }
} }
else if (DatabaseToolFilterModeRemove) { else if (DatabaseToolFilterModeRemove) {
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Remove Matching Messages in This Database", state.Db.CountMessages(filter).Pluralize("message") + " will be removed from this database. This action cannot be undone. Proceed?")) { if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Remove Matching Messages in This Database", messageCount.Pluralize("message") + " will be removed from this database. This action cannot be undone. Proceed?")) {
state.Db.RemoveMessages(filter, FilterRemovalMode.RemoveMatching); await ApplyFilterToDatabase(filter, FilterRemovalMode.RemoveMatching);
} }
} }
} }
private async Task ApplyFilterToDatabase(MessageFilter filter, FilterRemovalMode removalMode) {
await ProgressDialog.ShowIndeterminate(window, "Apply Filters", "Removing messages...", _ => state.Db.Messages.Remove(filter, removalMode));
}
} }

View File

@ -1,16 +1,12 @@
using System; using System;
using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Desktop.Main.Pages; using DHT.Desktop.Main.Pages;
using DHT.Server; using DHT.Server;
using DHT.Utils.Logging;
namespace DHT.Desktop.Main.Screens; namespace DHT.Desktop.Main.Screens;
sealed class MainContentScreenModel : IDisposable { sealed class MainContentScreenModel : IDisposable {
private static readonly Log Log = Log.ForType<MainContentScreenModel>();
public DatabasePage DatabasePage { get; } public DatabasePage DatabasePage { get; }
private DatabasePageModel DatabasePageModel { get; } private DatabasePageModel DatabasePageModel { get; }
@ -75,10 +71,6 @@ sealed class MainContentScreenModel : IDisposable {
StatusBarModel = new StatusBarModel(state); StatusBarModel = new StatusBarModel(state);
} }
public async Task Initialize() {
await TrackingPageModel.Initialize();
}
public void Dispose() { public void Dispose() {
AttachmentsPageModel.Dispose(); AttachmentsPageModel.Dispose();
ViewerPageModel.Dispose(); ViewerPageModel.Dispose();

View File

@ -32,7 +32,7 @@
<TextBlock Text="{Binding Version, StringFormat=Discord History Tracker v{0}}" FontSize="25" Margin="0 0 0 30" HorizontalAlignment="Center" /> <TextBlock Text="{Binding Version, StringFormat=Discord History Tracker v{0}}" FontSize="25" Margin="0 0 0 30" HorizontalAlignment="Center" />
<StackPanel Orientation="Horizontal" HorizontalAlignment="Center"> <StackPanel Orientation="Horizontal" HorizontalAlignment="Center">
<Button Command="{Binding OpenOrCreateDatabase}">Open or Create Database</Button> <Button Command="{Binding OpenOrCreateDatabase}" IsEnabled="{Binding IsOpenOrCreateDatabaseButtonEnabled}">Open or Create Database</Button>
<Button Command="{Binding ShowAboutDialog}">About</Button> <Button Command="{Binding ShowAboutDialog}">About</Button>
<Button Command="{Binding Exit}">Exit</Button> <Button Command="{Binding Exit}">Exit</Button>
</StackPanel> </StackPanel>

View File

@ -7,16 +7,22 @@ using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress; using DHT.Desktop.Dialogs.Progress;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Sqlite; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Models; using DHT.Utils.Models;
namespace DHT.Desktop.Main.Screens; namespace DHT.Desktop.Main.Screens;
sealed class WelcomeScreenModel : BaseModel, IDisposable { sealed class WelcomeScreenModel : BaseModel {
public string Version => Program.Version; public string Version => Program.Version;
public IDatabaseFile? Db { get; private set; } private bool isOpenOrCreateDatabaseButtonEnabled = true;
public bool HasDatabase => Db != null;
public bool IsOpenOrCreateDatabaseButtonEnabled {
get => isOpenOrCreateDatabaseButtonEnabled;
set => Change(ref isOpenOrCreateDatabaseButtonEnabled, value);
}
public event EventHandler<IDatabaseFile>? DatabaseSelected;
private readonly Window window; private readonly Window window;
@ -29,23 +35,25 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
this.window = window; this.window = window;
} }
public async void OpenOrCreateDatabase() { public async Task OpenOrCreateDatabase() {
IsOpenOrCreateDatabaseButtonEnabled = false;
try {
var path = await DatabaseGui.NewOpenOrCreateDatabaseFileDialog(window, Path.GetDirectoryName(dbFilePath)); var path = await DatabaseGui.NewOpenOrCreateDatabaseFileDialog(window, Path.GetDirectoryName(dbFilePath));
if (path != null) { if (path != null) {
await OpenOrCreateDatabaseFromPath(path); await OpenOrCreateDatabaseFromPath(path);
} }
} finally {
IsOpenOrCreateDatabaseButtonEnabled = true;
}
} }
public async Task OpenOrCreateDatabaseFromPath(string path) { public async Task OpenOrCreateDatabaseFromPath(string path) {
if (Db != null) {
Db = null;
}
dbFilePath = path; dbFilePath = path;
Db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, window, new SchemaUpgradeCallbacks(window));
OnPropertyChanged(nameof(Db)); var db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, window, new SchemaUpgradeCallbacks(window));
OnPropertyChanged(nameof(HasDatabase)); if (db != null) {
DatabaseSelected?.Invoke(this, db);
}
} }
private sealed class SchemaUpgradeCallbacks : ISchemaUpgradeCallbacks { private sealed class SchemaUpgradeCallbacks : ISchemaUpgradeCallbacks {
@ -68,11 +76,7 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
await Task.Delay(TimeSpan.FromMilliseconds(600)); await Task.Delay(TimeSpan.FromMilliseconds(600));
} }
await new ProgressDialog { await new ProgressDialog { DataContext = new ProgressDialogModel("Upgrading Database", StartUpgrade, progressItems: 3) }.ShowProgressDialog(window);
DataContext = new ProgressDialogModel(StartUpgrade, progressItems: 3) {
Title = "Upgrading Database"
}
}.ShowProgressDialog(window);
} }
private sealed class ProgressReporter : ISchemaUpgradeCallbacks.IProgressReporter { private sealed class ProgressReporter : ISchemaUpgradeCallbacks.IProgressReporter {
@ -109,22 +113,11 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
} }
} }
public void CloseDatabase() { public async Task ShowAboutDialog() {
Dispose(); await new AboutWindow { DataContext = new AboutWindowModel() }.ShowDialog(window);
OnPropertyChanged(nameof(Db));
OnPropertyChanged(nameof(HasDatabase));
}
public async void ShowAboutDialog() {
await new AboutWindow { DataContext = new AboutWindowModel() }.ShowDialog(this.window);
} }
public void Exit() { public void Exit() {
window.Close(); window.Close();
} }
public void Dispose() {
Db?.Dispose();
Db = null;
}
} }

View File

@ -2,7 +2,7 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>net8.0</TargetFramework> <TargetFramework>net8.0</TargetFramework>
<LangVersion>11</LangVersion> <LangVersion>12</LangVersion>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
</PropertyGroup> </PropertyGroup>

View File

@ -10,7 +10,7 @@ public readonly struct Message {
public long Timestamp { get; init; } public long Timestamp { get; init; }
public long? EditTimestamp { get; init; } public long? EditTimestamp { get; init; }
public ulong? RepliedToId { get; init; } public ulong? RepliedToId { get; init; }
public ImmutableArray<Attachment> Attachments { get; init; } public ImmutableList<Attachment> Attachments { get; init; }
public ImmutableArray<Embed> Embeds { get; init; } public ImmutableList<Embed> Embeds { get; init; }
public ImmutableArray<Reaction> Reactions { get; init; } public ImmutableList<Reaction> Reactions { get; init; }
} }

View File

@ -1,29 +1,32 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
namespace DHT.Server.Database; namespace DHT.Server.Database;
public static class DatabaseExtensions { public static class DatabaseExtensions {
public static void AddFrom(this IDatabaseFile target, IDatabaseFile source) { public static async Task AddFrom(this IDatabaseFile target, IDatabaseFile source) {
target.AddServers(source.GetAllServers()); await target.Users.Add(await source.Users.Get().ToListAsync());
target.AddChannels(source.GetAllChannels()); await target.Servers.Add(await source.Servers.Get().ToListAsync());
target.AddUsers(source.GetAllUsers().ToArray()); await target.Channels.Add(await source.Channels.Get().ToListAsync());
target.AddMessages(source.GetMessages().ToArray());
foreach (var download in source.GetDownloadsWithoutData()) { const int MessageBatchSize = 100;
target.AddDownload(download.Status == DownloadStatus.Success ? source.GetDownloadWithData(download) : download); List<Message> batchedMessages = new (MessageBatchSize);
await foreach (var message in source.Messages.Get()) {
batchedMessages.Add(message);
if (batchedMessages.Count >= MessageBatchSize) {
await target.Messages.Add(batchedMessages);
batchedMessages.Clear();
} }
} }
internal static void AddServers(this IDatabaseFile target, IEnumerable<Data.Server> servers) { await target.Messages.Add(batchedMessages);
foreach (var server in servers) {
target.AddServer(server);
}
}
internal static void AddChannels(this IDatabaseFile target, IEnumerable<Channel> channels) { await foreach (var download in source.Downloads.GetWithoutData()) {
foreach (var channel in channels) { await target.Downloads.AddDownload(download.Status == DownloadStatus.Success ? await source.Downloads.HydrateWithData(download) : download);
target.AddChannel(channel);
} }
} }
} }

View File

@ -1,90 +1,31 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using DHT.Server.Data; using System.Threading.Tasks;
using DHT.Server.Data.Aggregations; using DHT.Server.Database.Repositories;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database; namespace DHT.Server.Database;
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")] [SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public sealed class DummyDatabaseFile : IDatabaseFile { sealed class DummyDatabaseFile : IDatabaseFile {
public static DummyDatabaseFile Instance { get; } = new(); public static DummyDatabaseFile Instance { get; } = new ();
public string Path => ""; public string Path => "";
public DatabaseStatistics Statistics { get; } = new(); public DatabaseStatistics Statistics { get; } = new ();
public IUserRepository Users { get; } = new IUserRepository.Dummy();
public IServerRepository Servers { get; } = new IServerRepository.Dummy();
public IChannelRepository Channels { get; } = new IChannelRepository.Dummy();
public IMessageRepository Messages { get; } = new IMessageRepository.Dummy();
public IDownloadRepository Downloads { get; } = new IDownloadRepository.Dummy();
private DummyDatabaseFile() {} private DummyDatabaseFile() {}
public DatabaseStatisticsSnapshot SnapshotStatistics() { public Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return new(); return Task.FromResult(new DatabaseStatisticsSnapshot());
} }
public void AddServer(Data.Server server) {} public Task Vacuum() {
return Task.CompletedTask;
public List<Data.Server> GetAllServers() {
return new();
} }
public void AddChannel(Channel channel) {}
public List<Channel> GetAllChannels() {
return new();
}
public void AddUsers(User[] users) {}
public List<User> GetAllUsers() {
return new();
}
public void AddMessages(Message[] messages) {}
public int CountMessages(MessageFilter? filter = null) {
return 0;
}
public List<Message> GetMessages(MessageFilter? filter = null) {
return new();
}
public HashSet<ulong> GetMessageIds(MessageFilter? filter = null) {
return new();
}
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {}
public int CountAttachments(AttachmentFilter? filter = null) {
return new();
}
public List<Data.Download> GetDownloadsWithoutData() {
return new();
}
public Data.Download GetDownloadWithData(Data.Download download) {
return download;
}
public DownloadedAttachment? GetDownloadedAttachment(string url) {
return null;
}
public void AddDownload(Data.Download download) {}
public void EnqueueDownloadItems(AttachmentFilter? filter = null) {}
public List<DownloadItem> PullEnqueuedDownloadItems(int count) {
return new();
}
public void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {}
public DownloadStatusStatistics GetDownloadStatusStatistics() {
return new();
}
public void Vacuum() {}
public void Dispose() {} public void Dispose() {}
} }

View File

@ -21,7 +21,7 @@ public static class ViewerJsonExport {
var includedChannelIds = new HashSet<ulong>(); var includedChannelIds = new HashSet<ulong>();
var includedServerIds = new HashSet<ulong>(); var includedServerIds = new HashSet<ulong>();
var includedMessages = db.GetMessages(filter); var includedMessages = await db.Messages.Get(filter).ToListAsync();
var includedChannels = new List<Channel>(); var includedChannels = new List<Channel>();
foreach (var message in includedMessages) { foreach (var message in includedMessages) {
@ -29,23 +29,23 @@ public static class ViewerJsonExport {
includedChannelIds.Add(message.Channel); includedChannelIds.Add(message.Channel);
} }
foreach (var channel in db.GetAllChannels()) { await foreach (var channel in db.Channels.Get()) {
if (includedChannelIds.Contains(channel.Id)) { if (includedChannelIds.Contains(channel.Id)) {
includedChannels.Add(channel); includedChannels.Add(channel);
includedServerIds.Add(channel.Server); includedServerIds.Add(channel.Server);
} }
} }
var users = GenerateUserList(db, includedUserIds, out var userindex, out var userIndices); var (users, userIndex, userIndices) = await GenerateUserList(db, includedUserIds);
var servers = GenerateServerList(db, includedServerIds, out var serverindex); var (servers, serverIndices) = await GenerateServerList(db, includedServerIds);
var channels = GenerateChannelList(includedChannels, serverindex); var channels = GenerateChannelList(includedChannels, serverIndices);
perf.Step("Collect database data"); perf.Step("Collect database data");
var value = new ViewerJson { var value = new ViewerJson {
Meta = new ViewerJson.JsonMeta { Meta = new ViewerJson.JsonMeta {
Users = users, Users = users,
Userindex = userindex, Userindex = userIndex,
Servers = servers, Servers = servers,
Channels = channels Channels = channels
}, },
@ -60,12 +60,12 @@ public static class ViewerJsonExport {
perf.End(); perf.End();
} }
private static Dictionary<Snowflake, ViewerJson.JsonUser> GenerateUserList(IDatabaseFile db, HashSet<ulong> userIds, out List<Snowflake> userindex, out Dictionary<ulong, int> userIndices) { private static async Task<(Dictionary<Snowflake, ViewerJson.JsonUser> Users, List<Snowflake> UserIndex, Dictionary<ulong, int> UserIndices)> GenerateUserList(IDatabaseFile db, HashSet<ulong> userIds) {
var users = new Dictionary<Snowflake, ViewerJson.JsonUser>(); var users = new Dictionary<Snowflake, ViewerJson.JsonUser>();
userindex = new List<Snowflake>(); var userIndex = new List<Snowflake>();
userIndices = new Dictionary<ulong, int>(); var userIndices = new Dictionary<ulong, int>();
foreach (var user in db.GetAllUsers()) { await foreach (var user in db.Users.Get()) {
var id = user.Id; var id = user.Id;
if (!userIds.Contains(id)) { if (!userIds.Contains(id)) {
continue; continue;
@ -73,7 +73,7 @@ public static class ViewerJsonExport {
var idSnowflake = new Snowflake(id); var idSnowflake = new Snowflake(id);
userIndices[id] = users.Count; userIndices[id] = users.Count;
userindex.Add(idSnowflake); userIndex.Add(idSnowflake);
users[idSnowflake] = new ViewerJson.JsonUser { users[idSnowflake] = new ViewerJson.JsonUser {
Name = user.Name, Name = user.Name,
@ -82,14 +82,14 @@ public static class ViewerJsonExport {
}; };
} }
return users; return (users, userIndex, userIndices);
} }
private static List<ViewerJson.JsonServer> GenerateServerList(IDatabaseFile db, HashSet<ulong> serverIds, out Dictionary<ulong, int> serverIndices) { private static async Task<(List<ViewerJson.JsonServer> Servers, Dictionary<ulong, int> ServerIndices)> GenerateServerList(IDatabaseFile db, HashSet<ulong> serverIds) {
var servers = new List<ViewerJson.JsonServer>(); var servers = new List<ViewerJson.JsonServer>();
serverIndices = new Dictionary<ulong, int>(); var serverIndices = new Dictionary<ulong, int>();
foreach (var server in db.GetAllServers()) { await foreach (var server in db.Servers.Get()) {
var id = server.Id; var id = server.Id;
if (!serverIds.Contains(id)) { if (!serverIds.Contains(id)) {
continue; continue;
@ -103,7 +103,7 @@ public static class ViewerJsonExport {
}); });
} }
return servers; return (servers, serverIndices);
} }
private static Dictionary<Snowflake, ViewerJson.JsonChannel> GenerateChannelList(List<Channel> includedChannels, Dictionary<ulong, int> serverIndices) { private static Dictionary<Snowflake, ViewerJson.JsonChannel> GenerateChannelList(List<Channel> includedChannels, Dictionary<ulong, int> serverIndices) {

View File

@ -1,43 +1,19 @@
using System; using System;
using System.Collections.Generic; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Database.Repositories;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database; namespace DHT.Server.Database;
public interface IDatabaseFile : IDisposable { public interface IDatabaseFile : IDisposable {
string Path { get; } string Path { get; }
DatabaseStatistics Statistics { get; } DatabaseStatistics Statistics { get; }
DatabaseStatisticsSnapshot SnapshotStatistics(); Task<DatabaseStatisticsSnapshot> SnapshotStatistics();
void AddServer(Data.Server server); IUserRepository Users { get; }
List<Data.Server> GetAllServers(); IServerRepository Servers { get; }
IChannelRepository Channels { get; }
IMessageRepository Messages { get; }
IDownloadRepository Downloads { get; }
void AddChannel(Channel channel); Task Vacuum();
List<Channel> GetAllChannels();
void AddUsers(User[] users);
List<User> GetAllUsers();
void AddMessages(Message[] messages);
int CountMessages(MessageFilter? filter = null);
List<Message> GetMessages(MessageFilter? filter = null);
HashSet<ulong> GetMessageIds(MessageFilter? filter = null);
void RemoveMessages(MessageFilter filter, FilterRemovalMode mode);
int CountAttachments(AttachmentFilter? filter = null);
void AddDownload(Data.Download download);
List<Data.Download> GetDownloadsWithoutData();
Data.Download GetDownloadWithData(Data.Download download);
DownloadedAttachment? GetDownloadedAttachment(string url);
void EnqueueDownloadItems(AttachmentFilter? filter = null);
List<DownloadItem> PullEnqueuedDownloadItems(int count);
void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
DownloadStatusStatistics GetDownloadStatusStatistics();
void Vacuum();
} }

View File

@ -33,10 +33,8 @@ public static class LegacyArchiveImport {
var servers = ReadServerList(meta, fakeSnowflake); var servers = ReadServerList(meta, fakeSnowflake);
var newServersOnly = new HashSet<Data.Server>(servers); var newServersOnly = new HashSet<Data.Server>(servers);
var oldServersById = db.GetAllServers().ToDictionary(static server => server.Id, static server => server); var oldServersById = await db.Servers.Get().ToDictionaryAsync(static server => server.Id, static server => server);
var oldChannelsById = await db.Channels.Get().ToDictionaryAsync(static channel => channel.Id, static channel => channel);
var oldChannels = db.GetAllChannels();
var oldChannelsById = oldChannels.ToDictionary(static channel => channel.Id, static channel => channel);
foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) { foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) {
if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) { if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) {
@ -66,17 +64,17 @@ public static class LegacyArchiveImport {
perf.Step("Read channel list"); perf.Step("Read channel list");
var oldMessageIds = db.GetMessageIds(); var oldMessageIds = await db.Messages.GetIds().ToHashSetAsync();
var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake)) var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake))
.Where(message => !oldMessageIds.Contains(message.Id)) .Where(message => !oldMessageIds.Contains(message.Id))
.ToArray(); .ToArray();
perf.Step("Read messages"); perf.Step("Read messages");
db.AddUsers(users); await db.Users.Add(users);
db.AddServers(servers); await db.Servers.Add(servers);
db.AddChannels(channels); await db.Channels.Add(channels);
db.AddMessages(newMessages); await db.Messages.Add(newMessages);
perf.Step("Import into database"); perf.Step("Import into database");
} catch (HttpException e) { } catch (HttpException e) {
@ -179,9 +177,9 @@ public static class LegacyArchiveImport {
Timestamp = messageObj.RequireLong("t", path), Timestamp = messageObj.RequireLong("t", path),
EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null, EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null,
RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null, RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null,
Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableArray() : ImmutableArray<Attachment>.Empty, Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableList() : ImmutableList<Attachment>.Empty,
Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableArray() : ImmutableArray<Embed>.Empty, Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableList() : ImmutableList<Embed>.Empty,
Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableArray() : ImmutableArray<Reaction>.Empty, Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableList() : ImmutableList<Reaction>.Empty,
}; };
}).ToArray(); }).ToArray();
} }

View File

@ -0,0 +1,22 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IChannelRepository {
Task Add(IReadOnlyList<Channel> channels);
IAsyncEnumerable<Channel> Get();
internal sealed class Dummy : IChannelRepository {
public Task Add(IReadOnlyList<Channel> channels) {
return Task.CompletedTask;
}
public IAsyncEnumerable<Channel> Get() {
return AsyncEnumerable.Empty<Channel>();
}
}
}

View File

@ -0,0 +1,68 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database.Repositories;
public interface IDownloadRepository {
Task<long> CountAttachments(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
Task AddDownload(Data.Download download);
Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Download> GetWithoutData();
Task<Data.Download> HydrateWithData(Data.Download download);
Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl);
Task EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default);
Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
internal sealed class Dummy : IDownloadRepository {
public Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public Task AddDownload(Data.Download download) {
return Task.CompletedTask;
}
public Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
return Task.FromResult(new DownloadStatusStatistics());
}
public IAsyncEnumerable<Data.Download> GetWithoutData() {
return AsyncEnumerable.Empty<Data.Download>();
}
public Task<Data.Download> HydrateWithData(Data.Download download) {
return Task.FromResult(download);
}
public Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
return Task.FromResult<DownloadedAttachment?>(null);
}
public Task EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.CompletedTask;
}
public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) {
return AsyncEnumerable.Empty<DownloadItem>();
}
public Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,42 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IMessageRepository {
Task Add(IReadOnlyList<Message> messages);
Task<long> Count(MessageFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<Message> Get(MessageFilter? filter = null);
IAsyncEnumerable<ulong> GetIds(MessageFilter? filter = null);
Task Remove(MessageFilter filter, FilterRemovalMode mode);
internal sealed class Dummy : IMessageRepository {
public Task Add(IReadOnlyList<Message> messages) {
return Task.CompletedTask;
}
public Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Message> Get(MessageFilter? filter) {
return AsyncEnumerable.Empty<Message>();
}
public IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
return AsyncEnumerable.Empty<ulong>();
}
public Task Remove(MessageFilter filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,21 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace DHT.Server.Database.Repositories;
public interface IServerRepository {
Task Add(IReadOnlyList<Data.Server> servers);
IAsyncEnumerable<Data.Server> Get();
internal sealed class Dummy : IServerRepository {
public Task Add(IReadOnlyList<Data.Server> servers) {
return Task.CompletedTask;
}
public IAsyncEnumerable<Data.Server> Get() {
return AsyncEnumerable.Empty<Data.Server>();
}
}
}

View File

@ -0,0 +1,22 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IUserRepository {
Task Add(IReadOnlyList<User> users);
IAsyncEnumerable<User> Get();
internal sealed class Dummy : IUserRepository {
public Task Add(IReadOnlyList<User> users) {
return Task.CompletedTask;
}
public IAsyncEnumerable<User> Get() {
return AsyncEnumerable.Empty<User>();
}
}
}

View File

@ -0,0 +1,77 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteChannelRepository : IChannelRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteChannelRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateChannelStatistics(conn);
}
private async Task UpdateChannelStatistics(ISqliteConnection conn) {
statistics.TotalChannels = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<Channel> channels) {
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("channels", [
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer)
]);
foreach (var channel in channels) {
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await UpdateChannelStatistics(conn);
}
public async IAsyncEnumerable<Channel> Get() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6),
};
}
}
}

View File

@ -0,0 +1,233 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteDownloadRepository : IDownloadRepository {
private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
public SqliteDownloadRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalDownloadsComputer) {
this.pool = pool;
this.totalDownloadsComputer = totalDownloadsComputer;
}
public async Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async Task AddDownload(Data.Download download) {
using (var conn = pool.Take()) {
await using var cmd = conn.Upsert("downloads", [
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("size", SqliteType.Integer),
("blob", SqliteType.Blob)
]);
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
await cmd.ExecuteNonQueryAsync();
}
totalDownloadsComputer.Recompute();
}
public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
static async Task LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0)
FROM (SELECT MAX(a.size) size
FROM attachments a
WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d)
GROUP BY a.normalized_url)
""");
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static async Task LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
using var conn = pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken);
return result;
}
public async IAsyncEnumerable<Data.Download> GetWithoutData() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
yield return new Data.Download(normalizedUrl, downloadUrl, status, size);
}
}
public async Task<Data.Download> HydrateWithData(Data.Download download) {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
await using var reader = await cmd.ExecuteReaderAsync();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
}
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take();
await using var cmd = conn.Command(
"""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
"""
);
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync();
if (!reader.Read()) {
return null;
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public async Task EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
await using var cmd = conn.Command(
$"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>();
using var conn = pool.Take();
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
while (reader.Read()) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
});
}
}
if (found.Count != 0) {
await using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
foreach (var item in found) {
cmd.Set(":normalized_url", item.NormalizedUrl);
if (await cmd.ExecuteNonQueryAsync(cancellationToken) == 1) {
yield return item;
}
}
}
}
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM downloads
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
totalDownloadsComputer.Recompute();
}
}

View File

@ -0,0 +1,306 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteMessageRepository : IMessageRepository {
private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
public SqliteMessageRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalMessagesComputer, AsyncValueComputer<long>.Single totalAttachmentsComputer) {
this.pool = pool;
this.totalMessagesComputer = totalMessagesComputer;
this.totalAttachmentsComputer = totalAttachmentsComputer;
}
public async Task Add(IReadOnlyList<Message> messages) {
if (messages.Count == 0) {
return;
}
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
static async Task ExecuteDeleteByMessageId(SqliteCommand cmd, object id) {
cmd.Set(":message_id", id);
await cmd.ExecuteNonQueryAsync();
}
bool addedAttachments = false;
using (var conn = pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer)
]);
await using var deleteEditTimestampCmd = DeleteByMessageId(conn, "edit_timestamps");
await using var deleteRepliedToCmd = DeleteByMessageId(conn, "replied_to");
await using var deleteAttachmentsCmd = DeleteByMessageId(conn, "attachments");
await using var deleteEmbedsCmd = DeleteByMessageId(conn, "embeds");
await using var deleteReactionsCmd = DeleteByMessageId(conn, "reactions");
await using var editTimestampCmd = conn.Insert("edit_timestamps", [
("message_id", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer)
]);
await using var repliedToCmd = conn.Insert("replied_to", [
("message_id", SqliteType.Integer),
("replied_to_id", SqliteType.Integer)
]);
await using var attachmentCmd = conn.Insert("attachments", [
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("size", SqliteType.Integer),
("width", SqliteType.Integer),
("height", SqliteType.Integer)
]);
await using var embedCmd = conn.Insert("embeds", [
("message_id", SqliteType.Integer),
("json", SqliteType.Text)
]);
await using var reactionCmd = conn.Insert("reactions", [
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
]);
foreach (var message in messages) {
object messageId = message.Id;
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
await messageCmd.ExecuteNonQueryAsync();
await ExecuteDeleteByMessageId(deleteEditTimestampCmd, messageId);
await ExecuteDeleteByMessageId(deleteRepliedToCmd, messageId);
await ExecuteDeleteByMessageId(deleteAttachmentsCmd, messageId);
await ExecuteDeleteByMessageId(deleteEmbedsCmd, messageId);
await ExecuteDeleteByMessageId(deleteReactionsCmd, messageId);
if (message.EditTimestamp is {} timestamp) {
editTimestampCmd.Set(":message_id", messageId);
editTimestampCmd.Set(":edit_timestamp", timestamp);
await editTimestampCmd.ExecuteNonQueryAsync();
}
if (message.RepliedToId is {} repliedToId) {
repliedToCmd.Set(":message_id", messageId);
repliedToCmd.Set(":replied_to_id", repliedToId);
await repliedToCmd.ExecuteNonQueryAsync();
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":normalized_url", attachment.NormalizedUrl);
attachmentCmd.Set(":download_url", attachment.DownloadUrl);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
await attachmentCmd.ExecuteNonQueryAsync();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
await embedCmd.ExecuteNonQueryAsync();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
await reactionCmd.ExecuteNonQueryAsync();
}
}
}
await tx.CommitAsync();
}
totalMessagesComputer.Recompute();
if (addedAttachments) {
totalAttachmentsComputer.Recompute();
}
}
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
private sealed class MesageToManyCommand<T> : IAsyncDisposable {
private readonly SqliteCommand cmd;
private readonly Func<SqliteDataReader, T> readItem;
public MesageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
this.cmd = conn.Command(sql);
this.cmd.Add(":message_id", SqliteType.Integer);
this.readItem = readItem;
}
public async Task<ImmutableList<T>> GetItems(ulong messageId) {
cmd.Set(":message_id", messageId);
var items = ImmutableList<T>.Empty;
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
items = items.Add(readItem(reader));
}
return items;
}
public async ValueTask DisposeAsync() {
await cmd.DisposeAsync();
}
}
public async IAsyncEnumerable<Message> Get(MessageFilter? filter) {
using var conn = pool.Take();
const string AttachmentSql =
"""
SELECT attachment_id, name, type, normalized_url, download_url, size, width, height
FROM attachments
WHERE message_id = :message_id
""";
await using var attachmentCmd = new MesageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
NormalizedUrl = reader.GetString(3),
DownloadUrl = reader.GetString(4),
Size = reader.GetUint64(5),
Width = reader.IsDBNull(6) ? null : reader.GetInt32(6),
Height = reader.IsDBNull(7) ? null : reader.GetInt32(7),
});
const string EmbedSql =
"""
SELECT json
FROM embeds
WHERE message_id = :message_id
""";
await using var embedCmd = new MesageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
Json = reader.GetString(0)
});
const string ReactionSql =
"""
SELECT emoji_id, emoji_name, emoji_flags, count
FROM reactions
WHERE message_id = :message_id
""";
await using var reactionsCmd = new MesageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
EmojiId = reader.IsDBNull(0) ? null : reader.GetUint64(0),
EmojiName = reader.IsDBNull(1) ? null : reader.GetString(1),
EmojiFlags = (EmojiFlags) reader.GetInt16(2),
Count = reader.GetInt32(3),
});
await using var messageCmd = conn.Command(
$"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
"""
);
await using var reader = await messageCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
ulong messageId = reader.GetUint64(0);
yield return new Message {
Id = messageId,
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = await attachmentCmd.GetItems(messageId),
Embeds = await embedCmd.GetItems(messageId),
Reactions = await reactionsCmd.GetItems(messageId)
};
}
}
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
yield return reader.GetUint64(0);
}
}
public async Task Remove(MessageFilter filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM messages
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
totalMessagesComputer.Recompute();
}
}

View File

@ -0,0 +1,65 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteServerRepository : IServerRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteServerRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateServerStatistics(conn);
}
private async Task UpdateServerStatistics(ISqliteConnection conn) {
statistics.TotalServers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<Data.Server> servers) {
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("servers", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text)
]);
foreach (var server in servers) {
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await UpdateServerStatistics(conn);
}
public async IAsyncEnumerable<Data.Server> Get() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)),
};
}
}
}

View File

@ -0,0 +1,68 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteUserRepository : IUserRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteUserRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateUserStatistics(conn);
}
private async Task UpdateUserStatistics(ISqliteConnection conn) {
statistics.TotalUsers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<User> users) {
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text)
]);
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
await UpdateUserStatistics(conn);
}
public async IAsyncEnumerable<User> Get() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3),
};
}
}
}

View File

@ -1,4 +1,5 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Data.Common;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Database.Exceptions; using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
@ -168,7 +169,7 @@ sealed class Schema {
} }
} }
await using var tx = conn.BeginTransaction(); await using var tx = await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count; int totalUrls = normalizedUrls.Count;
int processedUrls = -1; int processedUrls = -1;
@ -214,9 +215,9 @@ sealed class Schema {
conn.Execute("PRAGMA cache_size = -20000"); conn.Execute("PRAGMA cache_size = -20000");
SqliteTransaction tx; DbTransaction tx;
await using (tx = conn.BeginTransaction()) { await using (tx = await conn.BeginTransactionAsync()) {
await reporter.SubWork("Deleting duplicates...", 0, 0); await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) { await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
@ -232,7 +233,7 @@ sealed class Schema {
int totalUrls = normalizedUrlsToOriginalUrls.Count; int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1; int processedUrls = -1;
tx = conn.BeginTransaction(); tx = await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) { await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text); updateCmd.Add(":normalized_url", SqliteType.Text);
@ -247,8 +248,8 @@ sealed class Schema {
await tx.CommitAsync(); await tx.CommitAsync();
await tx.DisposeAsync(); await tx.DisposeAsync();
tx = conn.BeginTransaction(); tx = await conn.BeginTransactionAsync();
updateCmd.Transaction = tx; updateCmd.Transaction = (SqliteTransaction) tx;
} }
updateCmd.Set(":normalized_url", normalizedUrl); updateCmd.Set(":normalized_url", normalizedUrl);

View File

@ -1,15 +1,8 @@
using System; using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Database.Repositories;
using DHT.Server.Data.Aggregations; using DHT.Server.Database.Sqlite.Repositories;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Collections;
using DHT.Utils.Logging;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
@ -36,7 +29,9 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
} }
if (wasOpened) { if (wasOpened) {
return new SqliteDatabaseFile(path, pool, computeTaskResultScheduler); var db = new SqliteDatabaseFile(path, pool, computeTaskResultScheduler);
await db.Initialize();
return db;
} }
else { else {
pool.Dispose(); pool.Dispose();
@ -47,14 +42,25 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
public string Path { get; } public string Path { get; }
public DatabaseStatistics Statistics { get; } public DatabaseStatistics Statistics { get; }
private readonly Log log; public IUserRepository Users => users;
public IServerRepository Servers => servers;
public IChannelRepository Channels => channels;
public IMessageRepository Messages => messages;
public IDownloadRepository Downloads => downloads;
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly SqliteUserRepository users;
private readonly SqliteServerRepository servers;
private readonly SqliteChannelRepository channels;
private readonly SqliteMessageRepository messages;
private readonly SqliteDownloadRepository downloads;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer; private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer; private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer; private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool, TaskScheduler computeTaskResultScheduler) { private SqliteDatabaseFile(string path, SqliteConnectionPool pool, TaskScheduler computeTaskResultScheduler) {
this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path));
this.pool = pool; this.pool = pool;
this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics); this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics);
@ -64,669 +70,65 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
this.Path = path; this.Path = path;
this.Statistics = new DatabaseStatistics(); this.Statistics = new DatabaseStatistics();
using (var conn = pool.Take()) { this.users = new SqliteUserRepository(pool, Statistics);
UpdateServerStatistics(conn); this.servers = new SqliteServerRepository(pool, Statistics);
UpdateChannelStatistics(conn); this.channels = new SqliteChannelRepository(pool, Statistics);
UpdateUserStatistics(conn); this.messages = new SqliteMessageRepository(pool, totalMessagesComputer, totalAttachmentsComputer);
} this.downloads = new SqliteDownloadRepository(pool, totalDownloadsComputer);
totalMessagesComputer.Recompute(); totalMessagesComputer.Recompute();
totalAttachmentsComputer.Recompute(); totalAttachmentsComputer.Recompute();
totalDownloadsComputer.Recompute(); totalDownloadsComputer.Recompute();
} }
private async Task Initialize() {
await users.Initialize();
await servers.Initialize();
await channels.Initialize();
}
public void Dispose() { public void Dispose() {
totalMessagesComputer.Cancel();
totalAttachmentsComputer.Cancel();
totalDownloadsComputer.Cancel();
pool.Dispose(); pool.Dispose();
} }
public DatabaseStatisticsSnapshot SnapshotStatistics() { public async Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return new DatabaseStatisticsSnapshot { return new DatabaseStatisticsSnapshot {
TotalServers = Statistics.TotalServers, TotalServers = Statistics.TotalServers,
TotalChannels = Statistics.TotalChannels, TotalChannels = Statistics.TotalChannels,
TotalUsers = Statistics.TotalUsers, TotalUsers = Statistics.TotalUsers,
TotalMessages = ComputeMessageStatistics(), TotalMessages = await ComputeMessageStatistics(),
}; };
} }
public void AddServer(Data.Server server) { public async Task Vacuum() {
using var conn = pool.Take(); using var conn = pool.Take();
using var cmd = conn.Upsert("servers", new[] { await conn.ExecuteAsync("VACUUM");
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
});
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery();
UpdateServerStatistics(conn);
} }
public List<Data.Server> GetAllServers() { private async Task<long> ComputeMessageStatistics() {
var perf = log.Start();
var list = new List<Data.Server>();
using var conn = pool.Take(); using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, type FROM servers"); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages", static reader => reader?.GetInt64(0) ?? 0L);
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)),
});
}
perf.End();
return list;
}
public void AddChannel(Channel channel) {
using var conn = pool.Take();
using var cmd = conn.Upsert("channels", new[] {
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer),
});
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery();
UpdateChannelStatistics(conn);
}
public List<Channel> GetAllChannels() {
var list = new List<Channel>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6),
});
}
return list;
}
public void AddUsers(User[] users) {
using var conn = pool.Take();
using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] {
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text),
});
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
cmd.ExecuteNonQuery();
}
tx.Commit();
UpdateUserStatistics(conn);
}
public List<User> GetAllUsers() {
var perf = log.Start();
var list = new List<User>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3),
});
}
perf.End();
return list;
}
public void AddMessages(Message[] messages) {
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
static void ExecuteDeleteByMessageId(SqliteCommand cmd, object id) {
cmd.Set(":message_id", id);
cmd.ExecuteNonQuery();
}
bool addedAttachments = false;
using (var conn = pool.Take()) {
using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] {
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer),
});
using var deleteEditTimestampCmd = DeleteByMessageId(conn, "edit_timestamps");
using var deleteRepliedToCmd = DeleteByMessageId(conn, "replied_to");
using var deleteAttachmentsCmd = DeleteByMessageId(conn, "attachments");
using var deleteEmbedsCmd = DeleteByMessageId(conn, "embeds");
using var deleteReactionsCmd = DeleteByMessageId(conn, "reactions");
using var editTimestampCmd = conn.Insert("edit_timestamps", new [] {
("message_id", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer),
});
using var repliedToCmd = conn.Insert("replied_to", new [] {
("message_id", SqliteType.Integer),
("replied_to_id", SqliteType.Integer),
});
using var attachmentCmd = conn.Insert("attachments", new[] {
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("size", SqliteType.Integer),
("width", SqliteType.Integer),
("height", SqliteType.Integer),
});
using var embedCmd = conn.Insert("embeds", new[] {
("message_id", SqliteType.Integer),
("json", SqliteType.Text),
});
using var reactionCmd = conn.Insert("reactions", new[] {
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer),
});
foreach (var message in messages) {
object messageId = message.Id;
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
messageCmd.ExecuteNonQuery();
ExecuteDeleteByMessageId(deleteEditTimestampCmd, messageId);
ExecuteDeleteByMessageId(deleteRepliedToCmd, messageId);
ExecuteDeleteByMessageId(deleteAttachmentsCmd, messageId);
ExecuteDeleteByMessageId(deleteEmbedsCmd, messageId);
ExecuteDeleteByMessageId(deleteReactionsCmd, messageId);
if (message.EditTimestamp is {} timestamp) {
editTimestampCmd.Set(":message_id", messageId);
editTimestampCmd.Set(":edit_timestamp", timestamp);
editTimestampCmd.ExecuteNonQuery();
}
if (message.RepliedToId is {} repliedToId) {
repliedToCmd.Set(":message_id", messageId);
repliedToCmd.Set(":replied_to_id", repliedToId);
repliedToCmd.ExecuteNonQuery();
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":normalized_url", attachment.NormalizedUrl);
attachmentCmd.Set(":download_url", attachment.DownloadUrl);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
attachmentCmd.ExecuteNonQuery();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
embedCmd.ExecuteNonQuery();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
reactionCmd.ExecuteNonQuery();
}
}
}
tx.Commit();
}
totalMessagesComputer.Recompute();
if (addedAttachments) {
totalAttachmentsComputer.Recompute();
}
}
public int CountMessages(MessageFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
return reader.Read() ? reader.GetInt32(0) : 0;
}
public List<Message> GetMessages(MessageFilter? filter = null) {
var perf = log.Start();
var list = new List<Message>();
var attachments = GetAllAttachments();
var embeds = GetAllEmbeds();
var reactions = GetAllReactions();
using var conn = pool.Take();
using var cmd = conn.Command($"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
""");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong id = reader.GetUint64(0);
list.Add(new Message {
Id = id,
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Attachment>.Empty,
Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Embed>.Empty,
Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Reaction>.Empty,
});
}
perf.End();
return list;
}
public HashSet<ulong> GetMessageIds(MessageFilter? filter = null) {
var perf = log.Start();
var ids = new HashSet<ulong>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ids.Add(reader.GetUint64(0));
}
perf.End();
return ids;
}
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {
var perf = log.Start();
DeleteFromTable("messages", filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching));
totalMessagesComputer.Recompute();
perf.End();
}
public int CountAttachments(AttachmentFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"));
using var reader = cmd.ExecuteReader();
return reader.Read() ? reader.GetInt32(0) : 0;
}
public void AddDownload(Data.Download download) {
using var conn = pool.Take();
using var cmd = conn.Upsert("downloads", new[] {
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("size", SqliteType.Integer),
("blob", SqliteType.Blob),
});
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
cmd.ExecuteNonQuery();
totalDownloadsComputer.Recompute();
}
public List<Data.Download> GetDownloadsWithoutData() {
var list = new List<Data.Download>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
list.Add(new Data.Download(normalizedUrl, downloadUrl, status, size));
}
return list;
}
public Data.Download GetDownloadWithData(Data.Download download) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
using var reader = cmd.ExecuteReader();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
}
public DownloadedAttachment? GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take();
using var cmd = conn.Command("""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
""");
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
using var reader = cmd.ExecuteReader();
if (!reader.Read()) {
return null;
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public void EnqueueDownloadItems(AttachmentFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command($"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
""");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.ExecuteNonQuery();
}
public List<DownloadItem> PullEnqueuedDownloadItems(int count) {
var found = new List<DownloadItem>();
var pulled = new List<DownloadItem>();
using var conn = pool.Take();
using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
});
}
}
if (found.Count != 0) {
using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
foreach (var item in found) {
cmd.Set(":normalized_url", item.NormalizedUrl);
if (cmd.ExecuteNonQuery() == 1) {
pulled.Add(item);
}
}
}
return pulled;
}
public void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
DeleteFromTable("downloads", filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching));
totalDownloadsComputer.Recompute();
}
public DownloadStatusStatistics GetDownloadStatusStatistics() {
static void LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result) {
using var cmd = conn.Command("SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0) FROM (SELECT MAX(a.size) size FROM attachments a WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d) GROUP BY a.normalized_url)");
using var reader = cmd.ExecuteReader();
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static void LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result) {
using var cmd = conn.Command("""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
""");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
using var reader = cmd.ExecuteReader();
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
using var conn = pool.Take();
LoadUndownloadedStatistics(conn, result);
LoadSuccessStatistics(conn, result);
return result;
}
private MultiDictionary<ulong, Attachment> GetAllAttachments() {
var dict = new MultiDictionary<ulong, Attachment>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, normalized_url, download_url, size, width, height FROM attachments");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Attachment {
Id = reader.GetUint64(1),
Name = reader.GetString(2),
Type = reader.IsDBNull(3) ? null : reader.GetString(3),
NormalizedUrl = reader.GetString(4),
DownloadUrl = reader.GetString(5),
Size = reader.GetUint64(6),
Width = reader.IsDBNull(7) ? null : reader.GetInt32(7),
Height = reader.IsDBNull(8) ? null : reader.GetInt32(8),
});
}
return dict;
}
private MultiDictionary<ulong, Embed> GetAllEmbeds() {
var dict = new MultiDictionary<ulong, Embed>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, json FROM embeds");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Embed {
Json = reader.GetString(1),
});
}
return dict;
}
private MultiDictionary<ulong, Reaction> GetAllReactions() {
var dict = new MultiDictionary<ulong, Reaction>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Reaction {
EmojiId = reader.IsDBNull(1) ? null : reader.GetUint64(1),
EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2),
EmojiFlags = (EmojiFlags) reader.GetInt16(3),
Count = reader.GetInt32(4),
});
}
return dict;
}
private void DeleteFromTable(string table, string whereClause) {
// Rider is being stupid...
StringBuilder build = new StringBuilder()
.Append("DELETE ")
.Append("FROM ")
.Append(table)
.Append(whereClause);
using var conn = pool.Take();
using var cmd = conn.Command(build.ToString());
cmd.ExecuteNonQuery();
}
public void Vacuum() {
using var conn = pool.Take();
using var cmd = conn.Command("VACUUM");
cmd.ExecuteNonQuery();
}
private void UpdateServerStatistics(ISqliteConnection conn) {
Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
}
private void UpdateChannelStatistics(ISqliteConnection conn) {
Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
}
private void UpdateUserStatistics(ISqliteConnection conn) {
Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
}
private long ComputeMessageStatistics() {
using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
} }
private void UpdateMessageStatistics(long totalMessages) { private void UpdateMessageStatistics(long totalMessages) {
Statistics.TotalMessages = totalMessages; Statistics.TotalMessages = totalMessages;
} }
private long ComputeAttachmentStatistics() { private async Task<long> ComputeAttachmentStatistics() {
using var conn = pool.Take(); using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(DISTINCT normalized_url) FROM attachments") as long? ?? 0L; return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments", static reader => reader?.GetInt64(0) ?? 0L);
} }
private void UpdateAttachmentStatistics(long totalAttachments) { private void UpdateAttachmentStatistics(long totalAttachments) {
Statistics.TotalAttachments = totalAttachments; Statistics.TotalAttachments = totalAttachments;
} }
private long ComputeDownloadStatistics() { private async Task<long> ComputeDownloadStatistics() {
using var conn = pool.Take(); using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(*) FROM downloads") as long? ?? 0L; return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L);
} }
private void UpdateDownloadStatistics(long totalDownloads) { private void UpdateDownloadStatistics(long totalDownloads) {

View File

@ -8,9 +8,13 @@ using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite; namespace DHT.Server.Database.Sqlite;
static class SqliteFilters { static class SqliteFilters {
private static string WhereAll(bool invert) {
return invert ? "WHERE FALSE" : "";
}
public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) { public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) { if (filter == null || filter.IsEmpty) {
return ""; return WhereAll(invert);
} }
var where = new SqliteWhereGenerator(tableAlias, invert); var where = new SqliteWhereGenerator(tableAlias, invert);
@ -39,8 +43,8 @@ static class SqliteFilters {
} }
public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) { public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) { if (filter == null || filter.IsEmpty) {
return ""; return WhereAll(invert);
} }
var where = new SqliteWhereGenerator(tableAlias, invert); var where = new SqliteWhereGenerator(tableAlias, invert);
@ -60,8 +64,8 @@ static class SqliteFilters {
} }
public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) { public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) { if (filter == null || filter.IsEmpty) {
return ""; return WhereAll(invert);
} }
var where = new SqliteWhereGenerator(tableAlias, invert); var where = new SqliteWhereGenerator(tableAlias, invert);

View File

@ -1,7 +1,7 @@
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace DHT.Server.Database.Sqlite; namespace DHT.Server.Database.Sqlite.Utils;
public interface ISchemaUpgradeCallbacks { public interface ISchemaUpgradeCallbacks {
Task<bool> CanUpgrade(); Task<bool> CanUpgrade();

View File

@ -1,12 +1,16 @@
using System; using System;
using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils; namespace DHT.Server.Database.Sqlite.Utils;
static class SqliteExtensions { static class SqliteExtensions {
public static SqliteTransaction BeginTransaction(this ISqliteConnection conn) { public static ValueTask<DbTransaction> BeginTransactionAsync(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransaction(); return conn.InnerConnection.BeginTransactionAsync();
} }
public static SqliteCommand Command(this ISqliteConnection conn, string sql) { public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
@ -20,6 +24,18 @@ static class SqliteExtensions {
cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery();
} }
public static async Task<int> ExecuteAsync(this ISqliteConnection conn, [LanguageInjection("sql")] string sql, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public static async Task<T> ExecuteReaderAsync<T>(this ISqliteConnection conn, string sql, Func<SqliteDataReader?, T> readFunction, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
return reader.Read() ? readFunction(reader) : readFunction(null);
}
public static object? SelectScalar(this ISqliteConnection conn, string sql) { public static object? SelectScalar(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql); using var cmd = conn.Command(sql);
return cmd.ExecuteScalar(); return cmd.ExecuteScalar();
@ -52,7 +68,7 @@ static class SqliteExtensions {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column }); CreateParameters(cmd, new[] { column });
return cmd; return cmd;
} }

View File

@ -45,7 +45,7 @@ sealed class DownloaderTask : IAsyncDisposable {
private async Task RunQueueWriterTask() { private async Task RunQueueWriterTask() {
while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) { while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) {
var newItems = db.PullEnqueuedDownloadItems(QueueSize); var newItems = await db.Downloads.PullEnqueuedDownloadItems(QueueSize, cancellationToken).ToListAsync(cancellationToken);
if (newItems.Count == 0) { if (newItems.Count == 0) {
await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken); await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken);
continue; continue;
@ -70,14 +70,14 @@ sealed class DownloaderTask : IAsyncDisposable {
try { try {
var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken); var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken);
db.AddDownload(Data.Download.NewSuccess(item, downloadedBytes)); await db.Downloads.AddDownload(Data.Download.NewSuccess(item, downloadedBytes));
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
// Ignore. // Ignore.
} catch (HttpRequestException e) { } catch (HttpRequestException e) {
db.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size)); await db.Downloads.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size));
log.Error(e); log.Error(e);
} catch (Exception e) { } catch (Exception e) {
db.AddDownload(Data.Download.NewFailure(item, null, item.Size)); await db.Downloads.AddDownload(Data.Download.NewFailure(item, null, item.Size));
log.Error(e); log.Error(e);
} finally { } finally {
try { try {

View File

@ -10,15 +10,15 @@ namespace DHT.Server.Endpoints;
sealed class GetAttachmentEndpoint : BaseEndpoint { sealed class GetAttachmentEndpoint : BaseEndpoint {
public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {} public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {}
protected override Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!); string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!);
DownloadedAttachment? maybeDownloadedAttachment = Db.GetDownloadedAttachment(attachmentUrl); DownloadedAttachment? maybeDownloadedAttachment = await Db.Downloads.GetDownloadedAttachment(attachmentUrl);
if (maybeDownloadedAttachment is {} downloadedAttachment) { if (maybeDownloadedAttachment is {} downloadedAttachment) {
return Task.FromResult<IHttpOutput>(new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data)); return new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data);
} }
else { else {
return Task.FromResult<IHttpOutput>(new HttpOutput.Redirect(attachmentUrl, permanent: false)); return new HttpOutput.Redirect(attachmentUrl, permanent: false);
} }
} }
} }

View File

@ -16,19 +16,19 @@ sealed class TrackChannelEndpoint : BaseEndpoint {
var server = ReadServer(root.RequireObject("server"), "server"); var server = ReadServer(root.RequireObject("server"), "server");
var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id); var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
Db.AddServer(server); await Db.Servers.Add([server]);
Db.AddChannel(channel); await Db.Channels.Add([channel]);
return HttpOutput.None; return HttpOutput.None;
} }
private static Data.Server ReadServer(JsonElement json, string path) => new() { private static Data.Server ReadServer(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Name = json.RequireString("name", path), Name = json.RequireString("name", path),
Type = ServerTypes.FromString(json.RequireString("type", path)) ?? throw new HttpException(HttpStatusCode.BadRequest, "Server type must be either 'SERVER', 'GROUP', or 'DM'.") Type = ServerTypes.FromString(json.RequireString("type", path)) ?? throw new HttpException(HttpStatusCode.BadRequest, "Server type must be either 'SERVER', 'GROUP', or 'DM'.")
}; };
private static Channel ReadChannel(JsonElement json, string path, ulong serverId) => new() { private static Channel ReadChannel(JsonElement json, string path, ulong serverId) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Server = serverId, Server = serverId,
Name = json.RequireString("name", path), Name = json.RequireString("name", path),

View File

@ -39,14 +39,14 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
} }
var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds }; var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds };
bool anyNewMessages = Db.CountMessages(addedMessageFilter) < messages.Length; bool anyNewMessages = await Db.Messages.Count(addedMessageFilter) < addedMessageIds.Count;
Db.AddMessages(messages); await Db.Messages.Add(messages);
return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages); return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages);
} }
private static Message ReadMessage(JsonElement json, string path) => new() { private static Message ReadMessage(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Sender = json.RequireSnowflake("sender", path), Sender = json.RequireSnowflake("sender", path),
Channel = json.RequireSnowflake("channel", path), Channel = json.RequireSnowflake("channel", path),
@ -54,9 +54,9 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
Timestamp = json.RequireLong("timestamp", path), Timestamp = json.RequireLong("timestamp", path),
EditTimestamp = json.HasKey("editTimestamp") ? json.RequireLong("editTimestamp", path) : null, EditTimestamp = json.HasKey("editTimestamp") ? json.RequireLong("editTimestamp", path) : null,
RepliedToId = json.HasKey("repliedToId") ? json.RequireSnowflake("repliedToId", path) : null, RepliedToId = json.HasKey("repliedToId") ? json.RequireSnowflake("repliedToId", path) : null,
Attachments = json.HasKey("attachments") ? ReadAttachments(json.RequireArray("attachments", path + ".attachments"), path + ".attachments[]").ToImmutableArray() : ImmutableArray<Attachment>.Empty, Attachments = json.HasKey("attachments") ? ReadAttachments(json.RequireArray("attachments", path + ".attachments"), path + ".attachments[]").ToImmutableList() : ImmutableList<Attachment>.Empty,
Embeds = json.HasKey("embeds") ? ReadEmbeds(json.RequireArray("embeds", path + ".embeds"), path + ".embeds[]").ToImmutableArray() : ImmutableArray<Embed>.Empty, Embeds = json.HasKey("embeds") ? ReadEmbeds(json.RequireArray("embeds", path + ".embeds"), path + ".embeds[]").ToImmutableList() : ImmutableList<Embed>.Empty,
Reactions = json.HasKey("reactions") ? ReadReactions(json.RequireArray("reactions", path + ".reactions"), path + ".reactions[]").ToImmutableArray() : ImmutableArray<Reaction>.Empty, Reactions = json.HasKey("reactions") ? ReadReactions(json.RequireArray("reactions", path + ".reactions"), path + ".reactions[]").ToImmutableList() : ImmutableList<Reaction>.Empty,
}; };
[SuppressMessage("ReSharper", "ConvertToLambdaExpression")] [SuppressMessage("ReSharper", "ConvertToLambdaExpression")]

View File

@ -25,12 +25,12 @@ sealed class TrackUsersEndpoint : BaseEndpoint {
users[i++] = ReadUser(user, "user"); users[i++] = ReadUser(user, "user");
} }
Db.AddUsers(users); await Db.Users.Add(users);
return HttpOutput.None; return HttpOutput.None;
} }
private static User ReadUser(JsonElement json, string path) => new() { private static User ReadUser(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path), Id = json.RequireSnowflake("id", path),
Name = json.RequireString("name", path), Name = json.RequireString("name", path),
AvatarUrl = json.HasKey("avatar") ? json.RequireString("avatar", path) : null, AvatarUrl = json.HasKey("avatar") ? json.RequireString("avatar", path) : null,

View File

@ -12,6 +12,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.0" /> <PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.0" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="System.Reactive" Version="6.0.0" /> <PackageReference Include="System.Reactive" Version="6.0.0" />
</ItemGroup> </ItemGroup>

View File

@ -15,7 +15,7 @@ public sealed class AsyncValueComputer<TValue> {
private SoftHardCancellationToken? currentCancellationTokenSource; private SoftHardCancellationToken? currentCancellationTokenSource;
private bool wasHardCancelled = false; private bool wasHardCancelled = false;
private Func<TValue>? currentComputeFunction; private Func<Task<TValue>>? currentComputeFunction;
private bool hasComputeFunctionChanged = false; private bool hasComputeFunctionChanged = false;
private AsyncValueComputer(Action<TValue> resultProcessor, TaskScheduler resultTaskScheduler, bool processOutdatedResults) { private AsyncValueComputer(Action<TValue> resultProcessor, TaskScheduler resultTaskScheduler, bool processOutdatedResults) {
@ -31,7 +31,7 @@ public sealed class AsyncValueComputer<TValue> {
} }
} }
public void Compute(Func<TValue> func) { public void Compute(Func<Task<TValue>> func) {
lock (stateLock) { lock (stateLock) {
wasHardCancelled = false; wasHardCancelled = false;
@ -47,7 +47,7 @@ public sealed class AsyncValueComputer<TValue> {
} }
[SuppressMessage("ReSharper", "MethodSupportsCancellation")] [SuppressMessage("ReSharper", "MethodSupportsCancellation")]
private void EnqueueComputation(Func<TValue> func) { private void EnqueueComputation(Func<Task<TValue>> func) {
var cancellationTokenSource = new SoftHardCancellationToken(); var cancellationTokenSource = new SoftHardCancellationToken();
currentCancellationTokenSource?.RequestSoftCancellation(); currentCancellationTokenSource?.RequestSoftCancellation();
@ -84,9 +84,9 @@ public sealed class AsyncValueComputer<TValue> {
public sealed class Single { public sealed class Single {
private readonly AsyncValueComputer<TValue> baseComputer; private readonly AsyncValueComputer<TValue> baseComputer;
private readonly Func<TValue> resultComputer; private readonly Func<Task<TValue>> resultComputer;
internal Single(AsyncValueComputer<TValue> baseComputer, Func<TValue> resultComputer) { internal Single(AsyncValueComputer<TValue> baseComputer, Func<Task<TValue>> resultComputer) {
this.baseComputer = baseComputer; this.baseComputer = baseComputer;
this.resultComputer = resultComputer; this.resultComputer = resultComputer;
} }
@ -94,6 +94,10 @@ public sealed class AsyncValueComputer<TValue> {
public void Recompute() { public void Recompute() {
baseComputer.Compute(resultComputer); baseComputer.Compute(resultComputer);
} }
public void Cancel() {
baseComputer.Cancel();
}
} }
public static Builder WithResultProcessor(Action<TValue> resultProcessor, TaskScheduler? scheduler = null) { public static Builder WithResultProcessor(Action<TValue> resultProcessor, TaskScheduler? scheduler = null) {
@ -119,7 +123,7 @@ public sealed class AsyncValueComputer<TValue> {
return new AsyncValueComputer<TValue>(resultProcessor, resultTaskScheduler, processOutdatedResults); return new AsyncValueComputer<TValue>(resultProcessor, resultTaskScheduler, processOutdatedResults);
} }
public Single BuildWithComputer(Func<TValue> resultComputer) { public Single BuildWithComputer(Func<Task<TValue>> resultComputer) {
return new Single(Build(), resultComputer); return new Single(Build(), resultComputer);
} }
} }

View File

@ -0,0 +1,45 @@
using System;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Utils.Tasks;
public sealed class RestartableTask<T>(Action<T> resultProcessor, TaskScheduler resultScheduler) {
private readonly object monitor = new ();
private CancellationTokenSource? cancellationTokenSource;
public void Restart(Func<CancellationToken, Task<T>> resultComputer) {
lock (monitor) {
Cancel();
cancellationTokenSource = new CancellationTokenSource();
var taskCancellationTokenSource = cancellationTokenSource;
var taskCancellationToken = taskCancellationTokenSource.Token;
Task.Run(() => resultComputer(taskCancellationToken), taskCancellationToken)
.ContinueWith(task => resultProcessor(task.Result), taskCancellationToken, TaskContinuationOptions.OnlyOnRanToCompletion, resultScheduler)
.ContinueWith(_ => OnTaskFinished(taskCancellationTokenSource), CancellationToken.None);
}
}
public void Cancel() {
lock (monitor) {
if (cancellationTokenSource != null) {
cancellationTokenSource.Cancel();
cancellationTokenSource = null;
}
}
}
private void OnTaskFinished(CancellationTokenSource taskCancellationTokenSource) {
lock (monitor) {
taskCancellationTokenSource.Dispose();
if (cancellationTokenSource == taskCancellationTokenSource) {
cancellationTokenSource = null;
}
}
}
}

View File

@ -0,0 +1,84 @@
using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace DHT.Utils.Tasks;
public abstract class ThrottledTaskBase<T> : IDisposable {
private readonly Channel<Func<CancellationToken, T>> taskChannel = Channel.CreateBounded<Func<CancellationToken, T>>(new BoundedChannelOptions(capacity: 1) {
SingleReader = true,
SingleWriter = false,
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.DropOldest
});
private readonly CancellationTokenSource cancellationTokenSource = new ();
internal ThrottledTaskBase() {}
protected async Task ReaderTask() {
var cancellationToken = cancellationTokenSource.Token;
try {
await foreach (var item in taskChannel.Reader.ReadAllAsync(cancellationToken)) {
try {
await Run(item, cancellationToken);
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Ignore.
}
}
} catch (OperationCanceledException) {
// Ignore.
} finally {
cancellationTokenSource.Dispose();
}
}
protected abstract Task Run(Func<CancellationToken, T> func, CancellationToken cancellationToken);
public void Post(Func<CancellationToken, T> resultComputer) {
taskChannel.Writer.TryWrite(resultComputer);
}
public void Dispose() {
taskChannel.Writer.Complete();
cancellationTokenSource.Cancel();
}
}
public sealed class ThrottledTask : ThrottledTaskBase<Task> {
private readonly Action resultProcessor;
private readonly TaskScheduler resultScheduler;
public ThrottledTask(Action resultProcessor, TaskScheduler resultScheduler) {
this.resultProcessor = resultProcessor;
this.resultScheduler = resultScheduler;
Task.Run(ReaderTask);
}
protected override async Task Run(Func<CancellationToken, Task> func, CancellationToken cancellationToken) {
await func(cancellationToken);
await Task.Factory.StartNew(resultProcessor, cancellationToken, TaskCreationOptions.None, resultScheduler);
}
}
public sealed class ThrottledTask<T> : ThrottledTaskBase<Task<T>> {
private readonly Action<T> resultProcessor;
private readonly TaskScheduler resultScheduler;
public ThrottledTask(Action<T> resultProcessor, TaskScheduler resultScheduler) {
this.resultProcessor = resultProcessor;
this.resultScheduler = resultScheduler;
Task.Run(ReaderTask);
}
protected override async Task Run(Func<CancellationToken, Task<T>> func, CancellationToken cancellationToken) {
T result = await func(cancellationToken);
await Task.Factory.StartNew(() => resultProcessor(result), cancellationToken, TaskCreationOptions.None, resultScheduler);
}
}