diff --git a/Backend.Tests/Controllers/AudioControllerTests.cs b/Backend.Tests/Controllers/AudioControllerTests.cs index 0c04a2ad83..8725ff12b1 100644 --- a/Backend.Tests/Controllers/AudioControllerTests.cs +++ b/Backend.Tests/Controllers/AudioControllerTests.cs @@ -6,6 +6,7 @@ using BackendFramework.Models; using BackendFramework.Services; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; using NUnit.Framework; namespace Backend.Tests.Controllers @@ -52,6 +53,26 @@ public void TearDown() _projRepo.Delete(_projId); } + [Test] + public void TestDownloadAudioFileInvalidArguments() + { + var result = _audioController.DownloadAudioFile("invalid/projId", "wordId", "fileName"); + Assert.That(result is UnsupportedMediaTypeResult); + + result = _audioController.DownloadAudioFile("projId", "invalid/wordId", "fileName"); + Assert.That(result is UnsupportedMediaTypeResult); + + result = _audioController.DownloadAudioFile("projId", "wordId", "invalid/fileName"); + Assert.That(result is UnsupportedMediaTypeResult); + } + + [Test] + public void TestDownloadAudioFileNoFile() + { + var result = _audioController.DownloadAudioFile("projId", "wordId", "fileName"); + Assert.That(result is BadRequestObjectResult); + } + [Test] public void TestAudioImport() { @@ -69,7 +90,7 @@ public void TestAudioImport() _ = _audioController.UploadAudioFile(_projId, word.Id, fileUpload).Result; var foundWord = _wordRepo.GetWord(_projId, word.Id).Result; - Assert.IsNotNull(foundWord?.Audio); + Assert.That(foundWord?.Audio, Is.Not.Null); } [Test] @@ -98,7 +119,7 @@ public void DeleteAudio() // Ensure the word with deleted audio is in the frontier Assert.That(frontier, Has.Count.EqualTo(1)); - Assert.AreNotEqual(frontier[0].Id, origWord.Id); + Assert.That(frontier[0].Id, Is.Not.EqualTo(origWord.Id)); Assert.That(frontier[0].Audio, Has.Count.EqualTo(0)); Assert.That(frontier[0].History, Has.Count.EqualTo(1)); } diff --git a/Backend.Tests/Helper/FileStorageTests.cs b/Backend.Tests/Helper/FileStorageTests.cs index d8a7300d91..df182caa65 100644 --- a/Backend.Tests/Helper/FileStorageTests.cs +++ b/Backend.Tests/Helper/FileStorageTests.cs @@ -1,4 +1,5 @@ using System; +using BackendFramework.Helper; using static BackendFramework.Helper.FileStorage; using NUnit.Framework; @@ -11,7 +12,7 @@ public void TestFileTypeExtension() { Assert.That(FileTypeExtension(FileType.Audio), Is.EqualTo(".webm")); Assert.That(FileTypeExtension(FileType.Avatar), Is.EqualTo(".jpg")); - Assert.Throws(() => { FileTypeExtension((FileType)99); }); + Assert.That(() => FileTypeExtension((FileType)99), Throws.TypeOf()); } [Test] @@ -19,22 +20,14 @@ public void TestFilePathIdSanitization() { const string invalidId = "@"; const string validId = "a"; - Assert.Throws( - () => GenerateAudioFilePathForWord(invalidId, validId)); - Assert.Throws( - () => GenerateAudioFilePathForWord(validId, invalidId)); - Assert.Throws( - () => GenerateAudioFilePath(invalidId, "file.mp3")); - Assert.Throws( - () => GenerateAudioFileDirPath(invalidId)); - Assert.Throws( - () => GenerateImportExtractedLocationDirPath(invalidId)); - Assert.Throws( - () => GenerateLiftImportDirPath(invalidId)); - Assert.Throws( - () => GenerateAvatarFilePath(invalidId)); - Assert.Throws( - () => GetProjectDir(invalidId)); + Assert.That(() => GenerateAudioFilePathForWord(invalidId, validId), Throws.TypeOf()); + Assert.That(() => GenerateAudioFilePathForWord(validId, invalidId), Throws.TypeOf()); + Assert.That(() => GenerateAudioFilePath(invalidId, "file.mp3"), Throws.TypeOf()); + Assert.That(() => GenerateAudioFileDirPath(invalidId), Throws.TypeOf()); + Assert.That(() => GenerateImportExtractedLocationDirPath(invalidId), Throws.TypeOf()); + Assert.That(() => GenerateLiftImportDirPath(invalidId), Throws.TypeOf()); + Assert.That(() => GenerateAvatarFilePath(invalidId), Throws.TypeOf()); + Assert.That(() => GetProjectDir(invalidId), Throws.TypeOf()); } } } diff --git a/Backend.Tests/Helper/SanitizationTests.cs b/Backend.Tests/Helper/SanitizationTests.cs index 3f3387fb33..9cae32a0b7 100644 --- a/Backend.Tests/Helper/SanitizationTests.cs +++ b/Backend.Tests/Helper/SanitizationTests.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using BackendFramework.Helper; using static BackendFramework.Helper.Sanitization; using NUnit.Framework; @@ -15,7 +16,7 @@ public class SanitizationTests [TestCaseSource(nameof(_validIds))] public void TestValidIds(string id) { - Assert.That(SanitizeId(id)); + Assert.That(SanitizeId(id), Is.EqualTo(id)); } private static List _invalidIds = new() @@ -48,7 +49,7 @@ public void TestValidIds(string id) [TestCaseSource(nameof(_invalidIds))] public void TestInvalidIds(string id) { - Assert.False(SanitizeId(id)); + Assert.That(() => SanitizeId(id), Throws.TypeOf()); } private static List _validFileNames = new() @@ -68,7 +69,7 @@ public void TestInvalidIds(string id) [TestCaseSource(nameof(_validFileNames))] public void TestValidFileNames(string fileName) { - Assert.That(SanitizeFileName(fileName)); + Assert.That(SanitizeFileName(fileName), Is.EqualTo(fileName)); } private static List _invalidFileNames = new() @@ -97,7 +98,7 @@ public void TestValidFileNames(string fileName) [TestCaseSource(nameof(_invalidFileNames))] public void TestInvalidFileNames(string fileName) { - Assert.False(SanitizeFileName(fileName)); + Assert.That(() => SanitizeFileName(fileName), Throws.TypeOf()); } private static List> _namesUnfriendlyFriendly = new() diff --git a/Backend/Controllers/AudioController.cs b/Backend/Controllers/AudioController.cs index d5af420467..2f1e656424 100644 --- a/Backend/Controllers/AudioController.cs +++ b/Backend/Controllers/AudioController.cs @@ -39,8 +39,13 @@ public IActionResult DownloadAudioFile(string projectId, string wordId, string f // } // Sanitize user input - if (!Sanitization.SanitizeId(projectId) || !Sanitization.SanitizeId(wordId) || - !Sanitization.SanitizeFileName(fileName)) + try + { + fileName = Sanitization.SanitizeFileName(fileName); + projectId = Sanitization.SanitizeId(projectId); + wordId = Sanitization.SanitizeId(wordId); + } + catch { return new UnsupportedMediaTypeResult(); } @@ -71,7 +76,12 @@ public async Task UploadAudioFile(string projectId, string wordId } // sanitize user input - if (!Sanitization.SanitizeId(projectId) || !Sanitization.SanitizeId(wordId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + wordId = Sanitization.SanitizeId(wordId); + } + catch { return new UnsupportedMediaTypeResult(); } @@ -123,7 +133,13 @@ public async Task DeleteAudioFile(string projectId, string wordId } // sanitize user input - if (!Sanitization.SanitizeId(projectId) || !Sanitization.SanitizeId(wordId)) + try + { + fileName = Sanitization.SanitizeFileName(fileName); + projectId = Sanitization.SanitizeId(projectId); + wordId = Sanitization.SanitizeId(wordId); + } + catch { return new UnsupportedMediaTypeResult(); } diff --git a/Backend/Controllers/LiftController.cs b/Backend/Controllers/LiftController.cs index fae260fb35..3e24548815 100644 --- a/Backend/Controllers/LiftController.cs +++ b/Backend/Controllers/LiftController.cs @@ -94,7 +94,11 @@ public async Task FinishUploadLiftFile(string projectId) internal async Task FinishUploadLiftFile(string projectId, string userId) { // Sanitize projectId - if (!Sanitization.SanitizeId(projectId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + } + catch { return new UnsupportedMediaTypeResult(); } @@ -150,7 +154,11 @@ public async Task UploadLiftFile(string projectId, [FromForm] Fil } // Sanitize projectId - if (!Sanitization.SanitizeId(projectId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + } + catch { return new UnsupportedMediaTypeResult(); } @@ -188,7 +196,11 @@ public async Task UploadLiftFile(string projectId, [FromForm] Fil private async Task AddImportToProject(string liftStoragePath, string projectId) { // Sanitize projectId - if (!Sanitization.SanitizeId(projectId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + } + catch { return new UnsupportedMediaTypeResult(); } @@ -277,7 +289,11 @@ private async Task ExportLiftFile(string projectId, string userId } // Sanitize projectId - if (!Sanitization.SanitizeId(projectId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + } + catch { return new UnsupportedMediaTypeResult(); } @@ -402,7 +418,11 @@ public async Task CanUploadLift(string projectId) } // Sanitize user input - if (!Sanitization.SanitizeId(projectId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + } + catch { return new UnsupportedMediaTypeResult(); } diff --git a/Backend/Controllers/ProjectController.cs b/Backend/Controllers/ProjectController.cs index 6573333988..7a55949a3f 100644 --- a/Backend/Controllers/ProjectController.cs +++ b/Backend/Controllers/ProjectController.cs @@ -187,7 +187,11 @@ public async Task DeleteProject(string projectId) } // Sanitize user input. - if (!Sanitization.SanitizeId(projectId)) + try + { + projectId = Sanitization.SanitizeId(projectId); + } + catch { return new UnsupportedMediaTypeResult(); } diff --git a/Backend/Helper/FileStorage.cs b/Backend/Helper/FileStorage.cs index 5eae788d3b..6bf071cc96 100644 --- a/Backend/Helper/FileStorage.cs +++ b/Backend/Helper/FileStorage.cs @@ -33,26 +33,15 @@ protected HomeFolderNotFoundException(SerializationInfo info, StreamingContext c : base(info, context) { } } - /// Indicates an invalid input id. - [Serializable] - public class InvalidIdException : Exception - { - public InvalidIdException() { } - - protected InvalidIdException(SerializationInfo info, StreamingContext context) - : base(info, context) { } - } - /// /// Generate a path to the file name of an audio file for the Project based on the Word ID. /// /// Throws when id invalid. public static string GenerateAudioFilePathForWord(string projectId, string wordId) { - if (!Sanitization.SanitizeId(projectId) || !Sanitization.SanitizeId(wordId)) - { - throw new InvalidIdException(); - } + projectId = Sanitization.SanitizeId(projectId); + wordId = Sanitization.SanitizeId(wordId); + return GenerateProjectFilePath(projectId, AudioPathSuffix, wordId, FileType.Audio); } @@ -62,10 +51,8 @@ public static string GenerateAudioFilePathForWord(string projectId, string wordI /// Throws when id invalid. public static string GenerateAudioFilePath(string projectId, string fileName) { - if (!Sanitization.SanitizeId(projectId)) - { - throw new InvalidIdException(); - } + projectId = Sanitization.SanitizeId(projectId); + return GenerateProjectFilePath(projectId, AudioPathSuffix, fileName); } @@ -75,10 +62,8 @@ public static string GenerateAudioFilePath(string projectId, string fileName) /// Throws when id invalid. public static string GenerateAudioFileDirPath(string projectId, bool createDir = true) { - if (!Sanitization.SanitizeId(projectId)) - { - throw new InvalidIdException(); - } + projectId = Sanitization.SanitizeId(projectId); + return GenerateProjectDirPath(projectId, AudioPathSuffix, createDir); } @@ -89,10 +74,8 @@ public static string GenerateAudioFileDirPath(string projectId, bool createDir = /// This function is not expected to be used often. public static string GenerateImportExtractedLocationDirPath(string projectId, bool createDir = true) { - if (!Sanitization.SanitizeId(projectId)) - { - throw new InvalidIdException(); - } + projectId = Sanitization.SanitizeId(projectId); + return GenerateProjectDirPath(projectId, ImportExtractedLocation, createDir); } @@ -102,10 +85,8 @@ public static string GenerateImportExtractedLocationDirPath(string projectId, bo /// Throws when id invalid. public static string GenerateLiftImportDirPath(string projectId, bool createDir = true) { - if (!Sanitization.SanitizeId(projectId)) - { - throw new InvalidIdException(); - } + projectId = Sanitization.SanitizeId(projectId); + return GenerateProjectDirPath(projectId, LiftImportSuffix, createDir); } @@ -115,10 +96,8 @@ public static string GenerateLiftImportDirPath(string projectId, bool createDir /// Throws when id invalid. public static string GenerateAvatarFilePath(string userId) { - if (!Sanitization.SanitizeId(userId)) - { - throw new InvalidIdException(); - } + userId = Sanitization.SanitizeId(userId); + return GenerateFilePath(AvatarsDir, userId, FileType.Avatar); } @@ -128,10 +107,8 @@ public static string GenerateAvatarFilePath(string userId) /// Throws when id invalid. public static string GetProjectDir(string projectId) { - if (!Sanitization.SanitizeId(projectId)) - { - throw new InvalidIdException(); - } + projectId = Sanitization.SanitizeId(projectId); + return GenerateProjectDirPath(projectId, "", false); } diff --git a/Backend/Helper/Sanitization.cs b/Backend/Helper/Sanitization.cs index 3d722760ae..0c1288f817 100644 --- a/Backend/Helper/Sanitization.cs +++ b/Backend/Helper/Sanitization.cs @@ -1,29 +1,56 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.Globalization; using System.Linq; +using System.Runtime.Serialization; using System.Text; namespace BackendFramework.Helper { + /// Indicates an invalid input file name. + [Serializable] + public class InvalidFileNameException : Exception + { + public InvalidFileNameException() : base() { } + + protected InvalidFileNameException(SerializationInfo info, StreamingContext context) : base(info, context) { } + } + + /// Indicates an invalid input id. + [Serializable] + public class InvalidIdException : Exception + { + public InvalidIdException() { } + + protected InvalidIdException(SerializationInfo info, StreamingContext context) : base(info, context) { } + } + public static class Sanitization { /// /// Validate that an ID field sent from a user does not contain any illegal characters. - /// /// This is especially important if, for example, user input ultimately is used in the creation of a path to /// disk. /// - public static bool SanitizeId(string id) + /// The input string, if it is already sanitized. + /// Throws with string isn't sanitized. + public static string SanitizeId(string id) { - return id.All(c => char.IsLetterOrDigit(c) || c == '-'); + if (id.All(c => char.IsLetterOrDigit(c) || c == '-')) + { + return id; + } + throw new InvalidIdException(); } /// /// Validate that a file name does not have any illegal characters (such as / or \) which could manipulate /// the path of files that are stored or retrieved. /// - public static bool SanitizeFileName(string fileName) + /// The input string, if it is already sanitized. + /// Throws when string isn't sanitized. + public static string SanitizeFileName(string fileName) { // For list of invalid characters per OS, see https://stackoverflow.com/a/31976060. var validCharacters = new List @@ -36,14 +63,18 @@ public static bool SanitizeFileName(string fileName) ')', ' ' }.ToImmutableList(); - return fileName.All(c => char.IsLetterOrDigit(c) || validCharacters.Contains(c)); + if (fileName.All(c => char.IsLetterOrDigit(c) || validCharacters.Contains(c))) + { + return fileName; + } + throw new InvalidFileNameException(); } /// /// Convert a string (e.g., a project name), into one friendly to use in a path. /// Uses alphanumeric and '-' '_' ',' '(' ')'. - /// Returns converted string, unless length 0, then returns fallback. /// + /// Converted string, unless length 0, then returns fallback. public static string MakeFriendlyForPath(string name, string fallback = "") { // Method modified from https://stackoverflow.com/a/780800