1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2024-11-25 14:42:44 +01:00

Compare commits

...

5 Commits

Author SHA1 Message Date
ef3e34066a
Release v40.0 2023-12-31 20:18:24 +01:00
37374eeb18
Migrate ConfigureAwait to Task.Run 2023-12-31 20:18:08 +01:00
23ddb45a0d
Make opening/saving viewer asynchronous 2023-12-31 20:18:08 +01:00
9904a711f7
Make database connection pool asynchronous 2023-12-31 19:47:28 +01:00
d5720c8758
Code cleanup 2023-12-31 19:44:44 +01:00
34 changed files with 173 additions and 192 deletions

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using DHT.Utils.Logging; using DHT.Utils.Logging;
namespace DHT.Desktop; namespace DHT.Desktop;
@ -8,15 +9,15 @@ sealed class Arguments {
private const int FirstArgument = 1; private const int FirstArgument = 1;
public static Arguments Empty => new(Array.Empty<string>()); public static Arguments Empty => new (Array.Empty<string>());
public bool Console { get; } public bool Console { get; }
public string? DatabaseFile { get; } public string? DatabaseFile { get; }
public ushort? ServerPort { get; } public ushort? ServerPort { get; }
public string? ServerToken { get; } public string? ServerToken { get; }
public Arguments(string[] args) { public Arguments(IReadOnlyList<string> args) {
for (int i = FirstArgument; i < args.Length; i++) { for (int i = FirstArgument; i < args.Count; i++) {
string key = args[i]; string key = args[i];
switch (key) { switch (key) {
@ -35,7 +36,7 @@ sealed class Arguments {
value = key; value = key;
key = "-db"; key = "-db";
} }
else if (i >= args.Length - 1) { else if (i >= args.Count - 1) {
Log.Warn("Missing value for command line argument: " + key); Log.Warn("Missing value for command line argument: " + key);
continue; continue;
} }

View File

@ -19,13 +19,13 @@ sealed class BytesValueConverter : IValueConverter {
} }
} }
private static readonly Unit[] Units = { private static readonly Unit[] Units = [
new ("B", decimalPlaces: 0), new Unit("B", decimalPlaces: 0),
new ("kB", decimalPlaces: 0), new Unit("kB", decimalPlaces: 0),
new ("MB", decimalPlaces: 1), new Unit("MB", decimalPlaces: 1),
new ("GB", decimalPlaces: 1), new Unit("GB", decimalPlaces: 1),
new ("TB", decimalPlaces: 1) new Unit("TB", decimalPlaces: 1)
}; ];
private const int Scale = 1000; private const int Scale = 1000;

View File

@ -59,7 +59,7 @@ class CheckBoxDialogModel : ObservableObject {
} }
sealed class CheckBoxDialogModel<T> : CheckBoxDialogModel { sealed class CheckBoxDialogModel<T> : CheckBoxDialogModel {
public new IReadOnlyList<CheckBoxItem<T>> Items { get; } private new IReadOnlyList<CheckBoxItem<T>> Items { get; }
public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.IsChecked); public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.IsChecked);

View File

@ -36,7 +36,7 @@ class TextBoxDialogModel : ObservableObject {
} }
sealed class TextBoxDialogModel<T> : TextBoxDialogModel { sealed class TextBoxDialogModel<T> : TextBoxDialogModel {
public new IReadOnlyList<TextBoxItem<T>> Items { get; } private new IReadOnlyList<TextBoxItem<T>> Items { get; }
public IEnumerable<TextBoxItem<T>> ValidItems => Items.Where(static item => item.IsValid); public IEnumerable<TextBoxItem<T>> ValidItems => Items.Where(static item => item.IsValid);

View File

@ -39,7 +39,7 @@ static class DiscordAppSettings {
public static async Task<bool?> AreDevToolsEnabled() { public static async Task<bool?> AreDevToolsEnabled() {
try { try {
var settingsJson = await ReadSettingsJson().ConfigureAwait(false); var settingsJson = await ReadSettingsJson();
return AreDevToolsEnabled(settingsJson); return AreDevToolsEnabled(settingsJson);
} catch (Exception e) { } catch (Exception e) {
Log.Error("Cannot read settings file."); Log.Error("Cannot read settings file.");

View File

@ -5,4 +5,4 @@ namespace DHT.Desktop.Discord;
[JsonSourceGenerationOptions(GenerationMode = JsonSourceGenerationMode.Default, WriteIndented = true)] [JsonSourceGenerationOptions(GenerationMode = JsonSourceGenerationMode.Default, WriteIndented = true)]
[JsonSerializable(typeof(JsonObject))] [JsonSerializable(typeof(JsonObject))]
sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext {} sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext;

View File

@ -92,7 +92,7 @@ sealed class DatabasePageModel {
await target.AddFrom(db); await target.AddFrom(db);
return true; return true;
} finally { } finally {
db.Dispose(); await db.DisposeAsync();
} }
}); });
} }

View File

@ -128,7 +128,7 @@ namespace DHT.Desktop.Main.Pages {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())]; return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
} }
private static readonly string[] RandomWords = { private static readonly string[] RandomWords = [
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado", "apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli", "banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant", "cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
@ -151,8 +151,8 @@ namespace DHT.Desktop.Main.Pages {
"vanilla", "vanilla",
"watercress", "watermelon", "watercress", "watermelon",
"yam", "yam",
"zucchini", "zucchini"
}; ];
private static string RandomText(Random rand, int maxWords) { private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3)); int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));

View File

@ -78,7 +78,7 @@ sealed partial class TrackingPageModel : ObservableObject {
} }
private async Task InitializeDevToolsToggle() { private async Task InitializeDevToolsToggle() {
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled(); bool? devToolsEnabled = await Task.Run(DiscordAppSettings.AreDevToolsEnabled);
if (devToolsEnabled.HasValue) { if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value; AreDevToolsEnabled = devToolsEnabled.Value;

View File

@ -65,10 +65,13 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
var fullPath = await PrepareTemporaryViewerFile(); var fullPath = await PrepareTemporaryViewerFile();
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token); var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
await WriteViewerFile(fullPath, strategy); await ProgressDialog.ShowIndeterminate(window, "Open Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(fullPath, strategy)));
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
Process.Start(new ProcessStartInfo(fullPath) {
UseShellExecute = true
});
} catch (Exception e) { } catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not save viewer: " + e.Message); await Dialog.ShowOk(window, "Open Viewer", "Could not create or save viewer: " + e.Message);
} }
} }
@ -106,9 +109,9 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
} }
try { try {
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance); await ProgressDialog.ShowIndeterminate(window, "Save Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(path, StandaloneViewerExportStrategy.Instance)));
} catch (Exception e) { } catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not save viewer: " + e.Message); await Dialog.ShowOk(window, "Save Viewer", "Could not create or save viewer: " + e.Message);
} }
} }

View File

@ -23,5 +23,7 @@ sealed class DummyDatabaseFile : IDatabaseFile {
return Task.CompletedTask; return Task.CompletedTask;
} }
public void Dispose() {} public ValueTask DisposeAsync() {
return ValueTask.CompletedTask;
}
} }

View File

@ -3,9 +3,9 @@ using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export; namespace DHT.Server.Database.Export;
[JsonSourceGenerationOptions( [JsonSourceGenerationOptions(
Converters = new [] { typeof(SnowflakeJsonSerializer) }, Converters = [typeof(SnowflakeJsonSerializer)],
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
GenerationMode = JsonSourceGenerationMode.Default GenerationMode = JsonSourceGenerationMode.Default
)] )]
[JsonSerializable(typeof(ViewerJson))] [JsonSerializable(typeof(ViewerJson))]
sealed partial class ViewerJsonContext : JsonSerializerContext {} sealed partial class ViewerJsonContext : JsonSerializerContext;

View File

@ -4,7 +4,7 @@ using DHT.Server.Database.Repositories;
namespace DHT.Server.Database; namespace DHT.Server.Database;
public interface IDatabaseFile : IDisposable { public interface IDatabaseFile : IAsyncDisposable {
string Path { get; } string Path { get; }
IUserRepository Users { get; } IUserRepository Users { get; }

View File

@ -4,4 +4,4 @@ namespace DHT.Server.Database.Import;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)] [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(DiscordEmbedLegacyJson))] [JsonSerializable(typeof(DiscordEmbedLegacyJson))]
sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext {} sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext;

View File

@ -22,7 +22,7 @@ sealed class SqliteAttachmentRepository : BaseSqliteRepository, IAttachmentRepos
} }
public async Task<long> Count(AttachmentFilter? filter, CancellationToken cancellationToken) { public async Task<long> Count(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
} }

View File

@ -16,7 +16,7 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
} }
public async Task Add(IReadOnlyList<Channel> channels) { public async Task Add(IReadOnlyList<Channel> channels) {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) { await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("channels", [ await using var cmd = conn.Upsert("channels", [
@ -47,12 +47,12 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
} }
public override async Task<long> Count(CancellationToken cancellationToken) { public override async Task<long> Count(CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async IAsyncEnumerable<Channel> Get() { public async IAsyncEnumerable<Channel> Get() {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels"); await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();

View File

@ -21,7 +21,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task AddDownload(Data.Download download) { public async Task AddDownload(Data.Download download) {
using (var conn = pool.Take()) { await using (var conn = await pool.Take()) {
await using var cmd = conn.Upsert("downloads", [ await using var cmd = conn.Upsert("downloads", [
("normalized_url", SqliteType.Text), ("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text), ("download_url", SqliteType.Text),
@ -42,7 +42,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public override async Task<long> Count(CancellationToken cancellationToken) { public override async Task<long> Count(CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
@ -97,14 +97,14 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
var result = new DownloadStatusStatistics(); var result = new DownloadStatusStatistics();
using var conn = pool.Take(); await using var conn = await pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken); await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken); await LoadSuccessStatistics(conn, result, cancellationToken);
return result; return result;
} }
public async IAsyncEnumerable<Data.Download> GetWithoutData() { public async IAsyncEnumerable<Data.Download> GetWithoutData() {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads"); await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();
@ -120,7 +120,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task<Data.Download> HydrateWithData(Data.Download download) { public async Task<Data.Download> HydrateWithData(Data.Download download) {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url"); await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl); cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
@ -136,7 +136,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) { public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command( await using var cmd = conn.Command(
""" """
@ -162,7 +162,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) { public async Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command( await using var cmd = conn.Command(
$""" $"""
@ -181,7 +181,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) { public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>(); var found = new List<DownloadItem>();
using var conn = pool.Take(); await using var conn = await pool.Take();
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) { await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued); cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
@ -215,7 +215,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) { public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) { await using (var conn = await pool.Take()) {
await conn.ExecuteAsync( await conn.ExecuteAsync(
$""" $"""
-- noinspection SqlWithoutWhere -- noinspection SqlWithoutWhere

View File

@ -36,7 +36,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
bool addedAttachments = false; bool addedAttachments = false;
using (var conn = pool.Take()) { await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync(); await using var tx = await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [ await using var messageCmd = conn.Upsert("messages", [
@ -170,7 +170,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) { public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
@ -205,7 +205,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async IAsyncEnumerable<Message> Get(MessageFilter? filter) { public async IAsyncEnumerable<Message> Get(MessageFilter? filter) {
using var conn = pool.Take(); await using var conn = await pool.Take();
const string AttachmentSql = const string AttachmentSql =
""" """
@ -281,7 +281,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) { public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause()); await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();
@ -292,7 +292,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async Task Remove(MessageFilter filter, FilterRemovalMode mode) { public async Task Remove(MessageFilter filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) { await using (var conn = await pool.Take()) {
await conn.ExecuteAsync( await conn.ExecuteAsync(
$""" $"""
-- noinspection SqlWithoutWhere -- noinspection SqlWithoutWhere

View File

@ -16,7 +16,7 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
} }
public async Task Add(IReadOnlyList<Data.Server> servers) { public async Task Add(IReadOnlyList<Data.Server> servers) {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) { await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("servers", [ await using var cmd = conn.Upsert("servers", [
@ -39,12 +39,12 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
} }
public override async Task<long> Count(CancellationToken cancellationToken) { public override async Task<long> Count(CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async IAsyncEnumerable<Data.Server> Get() { public async IAsyncEnumerable<Data.Server> Get() {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, name, type FROM servers"); await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();

View File

@ -16,7 +16,7 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
} }
public async Task Add(IReadOnlyList<User> users) { public async Task Add(IReadOnlyList<User> users) {
using (var conn = pool.Take()) { await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync(); await using var tx = await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("users", [ await using var cmd = conn.Upsert("users", [
@ -41,12 +41,12 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
} }
public override async Task<long> Count(CancellationToken cancellationToken) { public override async Task<long> Count(CancellationToken cancellationToken) {
using var conn = pool.Take(); await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async IAsyncEnumerable<User> Get() { public async IAsyncEnumerable<User> Get() {
using var conn = pool.Take(); await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users"); await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();

View File

@ -16,14 +16,14 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
Mode = SqliteOpenMode.ReadWriteCreate, Mode = SqliteOpenMode.ReadWriteCreate,
}; };
var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize); var pool = await SqliteConnectionPool.Create(connectionString, DefaultPoolSize);
bool wasOpened; bool wasOpened;
try { try {
using var conn = pool.Take(); await using var conn = await pool.Take();
wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks); wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks);
} catch (Exception) { } catch (Exception) {
pool.Dispose(); await pool.DisposeAsync();
throw; throw;
} }
@ -31,7 +31,7 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
return new SqliteDatabaseFile(path, pool); return new SqliteDatabaseFile(path, pool);
} }
else { else {
pool.Dispose(); await pool.DisposeAsync();
return null; return null;
} }
} }
@ -65,18 +65,18 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
downloads = new SqliteDownloadRepository(pool); downloads = new SqliteDownloadRepository(pool);
} }
public void Dispose() { public async ValueTask DisposeAsync() {
users.Dispose(); users.Dispose();
servers.Dispose(); servers.Dispose();
channels.Dispose(); channels.Dispose();
messages.Dispose(); messages.Dispose();
attachments.Dispose(); attachments.Dispose();
downloads.Dispose(); downloads.Dispose();
pool.Dispose(); await pool.DisposeAsync();
} }
public async Task Vacuum() { public async Task Vacuum() {
using var conn = pool.Take(); await using var conn = await pool.Take();
await conn.ExecuteAsync("VACUUM"); await conn.ExecuteAsync("VACUUM");
} }
} }

View File

@ -3,6 +3,6 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils; namespace DHT.Server.Database.Sqlite.Utils;
interface ISqliteConnection : IDisposable { interface ISqliteConnection : IAsyncDisposable {
SqliteConnection InnerConnection { get; } SqliteConnection InnerConnection { get; }
} }

View File

@ -1,100 +1,77 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading; using System.Threading;
using DHT.Utils.Logging; using System.Threading.Tasks;
using DHT.Utils.Collections;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils; namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteConnectionPool : IDisposable { sealed class SqliteConnectionPool : IAsyncDisposable {
public static async Task<SqliteConnectionPool> Create(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var pool = new SqliteConnectionPool(poolSize);
await pool.InitializePooledConnections(connectionStringBuilder);
return pool;
}
private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) { private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) {
connectionStringBuilder.Pooling = false; connectionStringBuilder.Pooling = false;
return connectionStringBuilder.ToString(); return connectionStringBuilder.ToString();
} }
private readonly object monitor = new (); private readonly int poolSize;
private readonly Random rand = new (); private readonly List<PooledConnection> all;
private volatile bool isDisposed; private readonly ConcurrentPool<PooledConnection> free;
private readonly BlockingCollection<PooledConnection> free = new (new ConcurrentStack<PooledConnection>()); private readonly CancellationTokenSource disposalTokenSource = new ();
private readonly List<PooledConnection> used; private readonly CancellationToken disposalToken;
public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) { private SqliteConnectionPool(int poolSize) {
this.poolSize = poolSize;
this.all = new List<PooledConnection>(poolSize);
this.free = new ConcurrentPool<PooledConnection>(poolSize);
this.disposalToken = disposalTokenSource.Token;
}
private async Task InitializePooledConnections(SqliteConnectionStringBuilder connectionStringBuilder) {
var connectionString = GetConnectionString(connectionStringBuilder); var connectionString = GetConnectionString(connectionStringBuilder);
for (int i = 0; i < poolSize; i++) { for (int i = 0; i < poolSize; i++) {
var conn = new SqliteConnection(connectionString); var conn = new SqliteConnection(connectionString);
conn.Open(); conn.Open();
var pooledConn = new PooledConnection(this, conn); var pooledConnection = new PooledConnection(this, conn);
using (var cmd = pooledConn.Command("PRAGMA journal_mode=WAL")) { await using (var cmd = pooledConnection.Command("PRAGMA journal_mode=WAL")) {
cmd.ExecuteNonQuery(); await cmd.ExecuteNonQueryAsync(disposalToken);
} }
free.Add(pooledConn); all.Add(pooledConnection);
} await free.Push(pooledConnection, disposalToken);
used = new List<PooledConnection>(poolSize);
}
private void ThrowIfDisposed() {
ObjectDisposedException.ThrowIf(isDisposed, nameof(SqliteConnectionPool));
}
public ISqliteConnection Take() {
while (true) {
ThrowIfDisposed();
lock (monitor) {
if (free.TryTake(out var conn)) {
used.Add(conn);
return conn;
}
else {
Log.ForType<SqliteConnectionPool>().Warn("Thread " + Environment.CurrentManagedThreadId + " is starving for connections.");
}
}
Thread.Sleep(TimeSpan.FromMilliseconds(rand.Next(100, 200)));
} }
} }
private void Return(PooledConnection conn) { public async Task<ISqliteConnection> Take() {
ThrowIfDisposed(); return await free.Pop(disposalToken);
lock (monitor) {
if (used.Remove(conn)) {
free.Add(conn);
}
}
} }
public void Dispose() { private async Task Return(PooledConnection conn) {
if (isDisposed) { await free.Push(conn, disposalToken);
}
public async ValueTask DisposeAsync() {
if (disposalToken.IsCancellationRequested) {
return; return;
} }
isDisposed = true; await disposalTokenSource.CancelAsync();
lock (monitor) { foreach (var conn in all) {
while (free.TryTake(out var conn)) { await conn.InnerConnection.CloseAsync();
Close(conn.InnerConnection); await conn.InnerConnection.DisposeAsync();
}
foreach (var conn in used) {
Close(conn.InnerConnection);
}
free.Dispose();
used.Clear();
} }
}
private static void Close(SqliteConnection conn) { disposalTokenSource.Dispose();
conn.Close();
conn.Dispose();
} }
private sealed class PooledConnection : ISqliteConnection { private sealed class PooledConnection : ISqliteConnection {
@ -107,8 +84,8 @@ sealed class SqliteConnectionPool : IDisposable {
this.InnerConnection = conn; this.InnerConnection = conn;
} }
void IDisposable.Dispose() { public async ValueTask DisposeAsync() {
pool.Return(this); await pool.Return(this);
} }
} }
} }

View File

@ -58,7 +58,7 @@ static class SqliteExtensions {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new[] { column }); CreateParameters(cmd, [column]);
return cmd; return cmd;
} }

View File

@ -5,7 +5,7 @@ namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteWhereGenerator { sealed class SqliteWhereGenerator {
private readonly string? tableAlias; private readonly string? tableAlias;
private readonly bool invert; private readonly bool invert;
private readonly List<string> conditions = new (); private readonly List<string> conditions = [];
public SqliteWhereGenerator(string? tableAlias, bool invert) { public SqliteWhereGenerator(string? tableAlias, bool invert) {
this.tableAlias = tableAlias; this.tableAlias = tableAlias;

View File

@ -4,7 +4,7 @@ using System.Collections.Frozen;
namespace DHT.Server.Download; namespace DHT.Server.Download;
static class DiscordCdn { static class DiscordCdn {
private static FrozenSet<string> CdnHosts { get; } = new [] { private static FrozenSet<string> CdnHosts { get; } = new[] {
"cdn.discordapp.com", "cdn.discordapp.com",
"cdn.discord.com", "cdn.discord.com",
}.ToFrozenSet(); }.ToFrozenSet();

View File

@ -11,12 +11,12 @@ using Microsoft.Extensions.Hosting;
namespace DHT.Server.Service; namespace DHT.Server.Service;
sealed class Startup { sealed class Startup {
private static readonly string[] AllowedOrigins = { private static readonly string[] AllowedOrigins = [
"https://discord.com", "https://discord.com",
"https://ptb.discord.com", "https://ptb.discord.com",
"https://canary.discord.com", "https://canary.discord.com",
"https://discordapp.com", "https://discordapp.com"
}; ];
public void ConfigureServices(IServiceCollection services) { public void ConfigureServices(IServiceCollection services) {
services.Configure<JsonOptions>(static options => { services.Configure<JsonOptions>(static options => {

View File

@ -22,6 +22,6 @@ public sealed class State : IAsyncDisposable {
public async ValueTask DisposeAsync() { public async ValueTask DisposeAsync() {
await Downloader.Stop(); await Downloader.Stop();
await Server.Stop(); await Server.Stop();
Db.Dispose(); await Db.DisposeAsync();
} }
} }

View File

@ -0,0 +1,45 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Utils.Collections;
public sealed class ConcurrentPool<T> {
private readonly SemaphoreSlim mutexSemaphore;
private readonly SemaphoreSlim availableItemSemaphore;
private readonly Stack<T> items;
public ConcurrentPool(int size) {
mutexSemaphore = new SemaphoreSlim(1);
availableItemSemaphore = new SemaphoreSlim(0, size);
items = new Stack<T>();
}
public async Task Push(T item, CancellationToken cancellationToken) {
await PushItem(item, cancellationToken);
availableItemSemaphore.Release();
}
public async Task<T> Pop(CancellationToken cancellationToken) {
await availableItemSemaphore.WaitAsync(cancellationToken);
return await PopItem(cancellationToken);
}
private async Task PushItem(T item, CancellationToken cancellationToken) {
await mutexSemaphore.WaitAsync(cancellationToken);
try {
items.Push(item);
} finally {
mutexSemaphore.Release();
}
}
private async Task<T> PopItem(CancellationToken cancellationToken) {
await mutexSemaphore.WaitAsync(cancellationToken);
try {
return items.Pop();
} finally {
mutexSemaphore.Release();
}
}
}

View File

@ -1,19 +0,0 @@
using System.Collections.Generic;
namespace DHT.Utils.Collections;
public sealed class MultiDictionary<TKey, TValue> where TKey : notnull {
private readonly Dictionary<TKey, List<TValue>> dict = new();
public void Add(TKey key, TValue value) {
if (!dict.TryGetValue(key, out var list)) {
dict[key] = list = new List<TValue>();
}
list.Add(value);
}
public List<TValue>? GetListOrNull(TKey key) {
return dict.TryGetValue(key, out var list) ? list : null;
}
}

View File

@ -13,42 +13,20 @@ public static class HttpOutput {
} }
} }
public sealed class Text : IHttpOutput { public sealed class Text(string text) : IHttpOutput {
private readonly string text;
public Text(string text) {
this.text = text;
}
public Task WriteTo(HttpResponse response) { public Task WriteTo(HttpResponse response) {
return response.WriteAsync(text, Encoding.UTF8); return response.WriteAsync(text, Encoding.UTF8);
} }
} }
public sealed class File : IHttpOutput { public sealed class File(string? contentType, byte[] bytes) : IHttpOutput {
private readonly string? contentType;
private readonly byte[] bytes;
public File(string? contentType, byte[] bytes) {
this.contentType = contentType;
this.bytes = bytes;
}
public async Task WriteTo(HttpResponse response) { public async Task WriteTo(HttpResponse response) {
response.ContentType = contentType ?? string.Empty; response.ContentType = contentType ?? string.Empty;
await response.Body.WriteAsync(bytes); await response.Body.WriteAsync(bytes);
} }
} }
public sealed class Redirect : IHttpOutput { public sealed class Redirect(string url, bool permanent) : IHttpOutput {
private readonly string url;
private readonly bool permanent;
public Redirect(string url, bool permanent) {
this.url = url;
this.permanent = permanent;
}
public Task WriteTo(HttpResponse response) { public Task WriteTo(HttpResponse response) {
response.Redirect(url, permanent); response.Redirect(url, permanent);
return Task.CompletedTask; return Task.CompletedTask;

View File

@ -5,4 +5,4 @@ namespace DHT.Utils.Http;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, GenerationMode = JsonSourceGenerationMode.Default)] [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(JsonElement))]
public sealed partial class JsonElementContext : JsonSerializerContext {} public sealed partial class JsonElementContext : JsonSerializerContext;

View File

@ -6,13 +6,7 @@ using System.Threading.Tasks;
namespace DHT.Utils.Resources; namespace DHT.Utils.Resources;
public sealed class ResourceLoader { public sealed class ResourceLoader(Assembly assembly) {
private readonly Assembly assembly;
public ResourceLoader(Assembly assembly) {
this.assembly = assembly;
}
private Stream GetEmbeddedStream(string filename) { private Stream GetEmbeddedStream(string filename) {
Stream? stream = null; Stream? stream = null;
foreach (var embeddedName in assembly.GetManifestResourceNames()) { foreach (var embeddedName in assembly.GetManifestResourceNames()) {
@ -35,7 +29,7 @@ public sealed class ResourceLoader {
} }
public async Task<string> ReadJoinedAsync(string path, char separator) { public async Task<string> ReadJoinedAsync(string path, char separator) {
StringBuilder joined = new(); StringBuilder joined = new ();
foreach (var embeddedName in assembly.GetManifestResourceNames()) { foreach (var embeddedName in assembly.GetManifestResourceNames()) {
if (embeddedName.Replace('\\', '/').StartsWith(path)) { if (embeddedName.Replace('\\', '/').StartsWith(path)) {

View File

@ -8,5 +8,5 @@ using DHT.Utils;
namespace DHT.Utils; namespace DHT.Utils;
static class Version { static class Version {
public const string Tag = "39.1.0.0"; public const string Tag = "40.0.0.0";
} }