Skip to content

Commit

Permalink
SW-4440 Validate photo files at upload time (#1466)
Browse files Browse the repository at this point in the history
To prevent clients from uploading photos in unsupported image formats, try loading
the image data as if we were about to generate a thumbnail for it. If that fails,
delete the uploaded file and return an error.
  • Loading branch information
sgrimm authored Nov 13, 2023
1 parent 580c07e commit 75d9856
Show file tree
Hide file tree
Showing 12 changed files with 71 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.context.request.WebRequest
import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException
import org.springframework.web.reactive.function.UnsupportedMediaTypeException
import org.springframework.web.server.ResponseStatusException
import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler

Expand Down Expand Up @@ -93,6 +94,15 @@ class ControllerExceptionHandler : ResponseEntityExceptionHandler() {
ex.message ?: "An internal error has occurred.", HttpStatus.BAD_REQUEST, request)
}

@ExceptionHandler
fun handleUnsupportedMediaTypeException(
ex: UnsupportedMediaTypeException,
request: WebRequest
): ResponseEntity<*> {
return simpleErrorResponse(
ex.message ?: "Unsupported media type.", HttpStatus.UNSUPPORTED_MEDIA_TYPE, request)
}

@ExceptionHandler
fun handleWebApplicationException(
ex: WebApplicationException,
Expand Down
13 changes: 13 additions & 0 deletions src/main/kotlin/com/terraformation/backend/file/FileService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,31 @@ class FileService(
) {
private val log = perClassLogger()

/**
* Stores a file on the file store and records its information in the database.
*
* @param validateFile Function to check that the file's contents are valid. If not, this should
* throw an exception.
* @param insertChildRows Function to write any additional use-case-specific data about the file.
* Called after the file's basic information has been inserted into the files table, and called
* in the same transaction that inserts into the files table. If this throws an exception, the
* transaction is rolled back and the file is deleted from the file store.
*/
@Throws(IOException::class)
fun storeFile(
category: String,
data: InputStream,
metadata: NewFileMetadata,
validateFile: ((URI) -> Unit)? = null,
insertChildRows: (FileId) -> Unit
): FileId {
val storageUrl = fileStore.newUrl(clock.instant(), category, metadata.contentType)

try {
fileStore.write(storageUrl, data)

validateFile?.invoke(storageUrl)

val filesRow =
FilesRow(
contentType = metadata.contentType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.terraformation.backend.file.SizedInputStream
import com.terraformation.backend.file.model.NewFileMetadata
import com.terraformation.backend.log.perClassLogger
import com.terraformation.backend.nursery.event.BatchDeletionStartedEvent
import com.terraformation.backend.util.ImageUtils
import jakarta.inject.Named
import java.io.InputStream
import java.time.InstantSource
Expand All @@ -29,14 +30,15 @@ class BatchPhotoService(
private val clock: InstantSource,
private val dslContext: DSLContext,
private val fileService: FileService,
private val imageUtils: ImageUtils,
) {
private val log = perClassLogger()

fun storePhoto(batchId: BatchId, data: InputStream, metadata: NewFileMetadata): FileId {
requirePermissions { updateBatch(batchId) }

val fileId =
fileService.storeFile("batch", data, metadata) { fileId ->
fileService.storeFile("batch", data, metadata, imageUtils::read) { fileId ->
batchPhotosDao.insert(
BatchPhotosRow(
batchId = batchId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import com.terraformation.backend.file.SizedInputStream
import com.terraformation.backend.file.model.NewFileMetadata
import com.terraformation.backend.log.perClassLogger
import com.terraformation.backend.nursery.event.WithdrawalDeletionStartedEvent
import com.terraformation.backend.util.ImageUtils
import jakarta.inject.Named
import java.io.InputStream
import org.jooq.Condition
Expand All @@ -24,6 +25,7 @@ import org.springframework.context.event.EventListener
class WithdrawalPhotoService(
private val dslContext: DSLContext,
private val fileService: FileService,
private val imageUtils: ImageUtils,
private val withdrawalPhotosDao: WithdrawalPhotosDao,
) {
private val log = perClassLogger()
Expand All @@ -32,7 +34,7 @@ class WithdrawalPhotoService(
requirePermissions { createWithdrawalPhoto(withdrawalId) }

val fileId =
fileService.storeFile("withdrawal", data, metadata) { fileId ->
fileService.storeFile("withdrawal", data, metadata, imageUtils::read) { fileId ->
withdrawalPhotosDao.insert(
WithdrawalPhotosRow(fileId = fileId, withdrawalId = withdrawalId))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class ReportFileService(
): FileId {
requirePermissions { updateReport(reportId) }

val fileId = fileService.storeFile("report", data, metadata, insertChildRow)
val fileId = fileService.storeFile("report", data, metadata, null, insertChildRow)

log.info("Stored ${metadata.contentType} file $fileId for report $reportId")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.terraformation.backend.file.model.ExistingFileMetadata
import com.terraformation.backend.file.model.FileMetadata
import com.terraformation.backend.file.model.NewFileMetadata
import com.terraformation.backend.log.perClassLogger
import com.terraformation.backend.util.ImageUtils
import jakarta.inject.Named
import java.io.IOException
import java.io.InputStream
Expand All @@ -28,6 +29,7 @@ class PhotoRepository(
private val accessionPhotosDao: AccessionPhotosDao,
private val dslContext: DSLContext,
private val fileService: FileService,
private val imageUtils: ImageUtils,
) {
private val log = perClassLogger()

Expand All @@ -36,7 +38,7 @@ class PhotoRepository(
requirePermissions { uploadPhoto(accessionId) }

val fileId =
fileService.storeFile("accession", data, metadata) { fileId ->
fileService.storeFile("accession", data, metadata, imageUtils::read) { fileId ->
accessionPhotosDao.insert(AccessionPhotosRow(accessionId = accessionId, fileId = fileId))
}

Expand Down
5 changes: 5 additions & 0 deletions src/test/kotlin/com/terraformation/backend/Helpers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ fun multiPolygon(scale: Double): MultiPolygon {
* Keycloak-related output such as registration URLs.
*/
fun dummyKeycloakInfo() = KeycloakInfo("client-id", "secret", "http://dummy/realms/terraware")

/** A 1-pixel PNG file for testing code that requires valid image data. */
val onePixelPng: ByteArray by lazy {
TestClock::class.java.getResourceAsStream("/file/pixel.png").use { it.readAllBytes() }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import com.terraformation.backend.file.ThumbnailStore
import com.terraformation.backend.file.model.FileMetadata
import com.terraformation.backend.mockUser
import com.terraformation.backend.nursery.event.BatchDeletionStartedEvent
import com.terraformation.backend.onePixelPng
import com.terraformation.backend.util.ImageUtils
import io.mockk.Runs
import io.mockk.every
import io.mockk.just
Expand Down Expand Up @@ -50,7 +52,7 @@ internal class BatchPhotoServiceTest : DatabaseTest(), RunsAsUser {
dslContext, Clock.fixed(Instant.EPOCH, ZoneOffset.UTC), filesDao, fileStore, thumbnailStore)
}
private val service: BatchPhotoService by lazy {
BatchPhotoService(batchPhotosDao, clock, dslContext, fileService)
BatchPhotoService(batchPhotosDao, clock, dslContext, fileService, ImageUtils(fileStore))
}

private val metadata = FileMetadata.of(MediaType.IMAGE_JPEG_VALUE, "filename", 123L)
Expand Down Expand Up @@ -99,11 +101,10 @@ internal class BatchPhotoServiceTest : DatabaseTest(), RunsAsUser {
inner class ReadPhoto {
@Test
fun `returns photo data`() {
val content = Random.nextBytes(10)
val fileId = storePhoto(content = content)
val fileId = storePhoto(content = onePixelPng)

val inputStream = service.readPhoto(batchId, fileId)
assertArrayEquals(content, inputStream.readAllBytes(), "File content")
assertArrayEquals(onePixelPng, inputStream.readAllBytes(), "File content")
}

@Test
Expand Down Expand Up @@ -274,7 +275,7 @@ internal class BatchPhotoServiceTest : DatabaseTest(), RunsAsUser {

private fun storePhoto(
batchId: BatchId = this.batchId,
content: ByteArray = ByteArray(0)
content: ByteArray = onePixelPng
): FileId {
return service.storePhoto(batchId, content.inputStream(), metadata)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import com.terraformation.backend.file.ThumbnailStore
import com.terraformation.backend.file.model.FileMetadata
import com.terraformation.backend.mockUser
import com.terraformation.backend.nursery.event.WithdrawalDeletionStartedEvent
import com.terraformation.backend.onePixelPng
import com.terraformation.backend.util.ImageUtils
import io.mockk.Runs
import io.mockk.every
import io.mockk.just
Expand All @@ -43,7 +45,7 @@ internal class WithdrawalPhotoServiceTest : DatabaseTest(), RunsAsUser {
dslContext, Clock.fixed(Instant.EPOCH, ZoneOffset.UTC), filesDao, fileStore, thumbnailStore)
}
private val service: WithdrawalPhotoService by lazy {
WithdrawalPhotoService(dslContext, fileService, withdrawalPhotosDao)
WithdrawalPhotoService(dslContext, fileService, ImageUtils(fileStore), withdrawalPhotosDao)
}

private val metadata = FileMetadata.of(MediaType.IMAGE_JPEG_VALUE, "filename", 123L)
Expand All @@ -69,11 +71,10 @@ internal class WithdrawalPhotoServiceTest : DatabaseTest(), RunsAsUser {

@Test
fun `readPhoto returns photo data`() {
val content = Random.nextBytes(10)
val fileId = storePhoto(content = content)
val fileId = storePhoto(content = onePixelPng)

val inputStream = service.readPhoto(withdrawalId, fileId)
assertArrayEquals(content, inputStream.readAllBytes(), "File content")
assertArrayEquals(onePixelPng, inputStream.readAllBytes(), "File content")
}

@Test
Expand Down Expand Up @@ -173,7 +174,7 @@ internal class WithdrawalPhotoServiceTest : DatabaseTest(), RunsAsUser {

private fun storePhoto(
withdrawalId: WithdrawalId = this.withdrawalId,
content: ByteArray = ByteArray(0)
content: ByteArray = onePixelPng
): FileId {
return service.storePhoto(withdrawalId, content.inputStream(), metadata)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import com.terraformation.backend.file.SizedInputStream
import com.terraformation.backend.file.ThumbnailStore
import com.terraformation.backend.file.model.FileMetadata
import com.terraformation.backend.mockUser
import com.terraformation.backend.onePixelPng
import com.terraformation.backend.util.ImageUtils
import io.mockk.Runs
import io.mockk.every
import io.mockk.just
Expand All @@ -29,7 +31,6 @@ import io.mockk.verify
import java.io.ByteArrayInputStream
import java.net.URI
import java.nio.file.NoSuchFileException
import java.nio.file.Path
import java.time.ZoneOffset
import java.time.ZonedDateTime
import kotlin.io.path.Path
Expand Down Expand Up @@ -65,6 +66,10 @@ class PhotoRepositoryTest : DatabaseTest(), RunsAsUser {
private val metadata = FileMetadata.of(contentType, filename, 1L)
private val clock = TestClock(uploadedTime)

private val sixPixelPng: ByteArray by lazy {
javaClass.getResourceAsStream("/file/sixPixels.png").use { it.readAllBytes() }
}

@BeforeEach
fun setUp() {
accessionStore =
Expand Down Expand Up @@ -94,21 +99,19 @@ class PhotoRepositoryTest : DatabaseTest(), RunsAsUser {
every { user.canUploadPhoto(any()) } returns true

fileService = FileService(dslContext, clock, filesDao, fileStore, thumbnailStore)
repository = PhotoRepository(accessionPhotosDao, dslContext, fileService)
repository = PhotoRepository(accessionPhotosDao, dslContext, fileService, ImageUtils(fileStore))

insertSiteData()
insertAccession(id = accessionId, number = accessionNumber)
}

@Test
fun `storePhoto writes file and database row`() {
val photoData = Random(System.currentTimeMillis()).nextBytes(10)

repository.storePhoto(accessionId, photoData.inputStream(), metadata)
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata)
fileStore.assertFileExists(photoStorageUrl)

val actualPhotoData = fileStore.read(photoStorageUrl)
assertArrayEquals(photoData, actualPhotoData.readAllBytes(), "File contents")
assertArrayEquals(onePixelPng, actualPhotoData.readAllBytes(), "File contents")

val expectedAccessionPhoto = AccessionPhotosRow(accessionId = accessionId)
val actualAccessionPhoto = accessionPhotosDao.fetchByAccessionId(accessionId).first()
Expand All @@ -126,37 +129,31 @@ class PhotoRepositoryTest : DatabaseTest(), RunsAsUser {

@Test
fun `storePhoto replaces existing photo with same filename`() {
val photoData1 = byteArrayOf(1, 2, 3)
val photoData2 = byteArrayOf(4, 5, 6)

every { thumbnailStore.deleteThumbnails(any()) } just Runs

repository.storePhoto(accessionId, photoData1.inputStream(), metadata)
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata)
every { random.nextLong() } returns 1
repository.storePhoto(accessionId, photoData2.inputStream(), metadata)
repository.storePhoto(accessionId, sixPixelPng.inputStream(), metadata)

fileStore.assertFileNotExists(photoStorageUrl, "Earlier photo file should have been deleted")
assertEquals(1, accessionPhotosDao.fetchByAccessionId(accessionId).size, "Number of photos")

val stream = repository.readPhoto(accessionId, filename)

assertArrayEquals(photoData2, stream.readAllBytes())
assertArrayEquals(sixPixelPng, stream.readAllBytes())
}

@Test
fun `readPhoto reads newest existing photo file`() {
val photoData1 = byteArrayOf(1, 2, 3)
val photoData2 = byteArrayOf(4, 5, 6)

repository.storePhoto(accessionId, photoData1.inputStream(), metadata)
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata)
every { random.nextLong() } returns 1
repository.storePhoto(accessionId, photoData2.inputStream(), metadata.copy(filename = "dupe"))
repository.storePhoto(accessionId, sixPixelPng.inputStream(), metadata.copy(filename = "dupe"))

filesDao.update(filesDao.findAll().map { it.copy(fileName = filename) })

val stream = repository.readPhoto(accessionId, filename)

assertArrayEquals(photoData2, stream.readAllBytes())
assertArrayEquals(sixPixelPng, stream.readAllBytes())
}

@Test
Expand All @@ -175,14 +172,13 @@ class PhotoRepositoryTest : DatabaseTest(), RunsAsUser {

@Test
fun `readPhoto returns thumbnail if photo dimensions are specified`() {
val photoData = Random.nextBytes(10)
val thumbnailData = Random.nextBytes(10)
val thumbnailStream =
SizedInputStream(ByteArrayInputStream(thumbnailData), thumbnailData.size.toLong())
val width = 123
val height = 456

repository.storePhoto(accessionId, photoData.inputStream(), metadata)
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata)
val fileId = filesDao.findAll().first().id!!

every { thumbnailStore.getThumbnailData(any(), any(), any()) } returns thumbnailStream
Expand All @@ -196,14 +192,12 @@ class PhotoRepositoryTest : DatabaseTest(), RunsAsUser {

@Test
fun `listPhotos does not return duplicate filenames`() {
val photoData = byteArrayOf(1, 2, 3)

every { random.nextLong() } returns 1
repository.storePhoto(accessionId, photoData.inputStream(), metadata.copy(filename = "1"))
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata.copy(filename = "1"))
every { random.nextLong() } returns 2
repository.storePhoto(accessionId, photoData.inputStream(), metadata.copy(filename = "2"))
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata.copy(filename = "2"))
every { random.nextLong() } returns 3
repository.storePhoto(accessionId, photoData.inputStream(), metadata.copy(filename = "3"))
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata.copy(filename = "3"))

filesDao.update(filesDao.findAll().map { it.copy(fileName = "1") })

Expand All @@ -212,16 +206,14 @@ class PhotoRepositoryTest : DatabaseTest(), RunsAsUser {

@Test
fun `deleteAllPhotos deletes multiple photos`() {
val photoData = Random.nextBytes(10)

every { thumbnailStore.deleteThumbnails(any()) } just Runs
every { user.canUpdateAccession(any()) } returns true

every { random.nextLong() } returns 1L
repository.storePhoto(accessionId, photoData.inputStream(), metadata.copy(filename = "1.jpg"))
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata.copy(filename = "1.jpg"))

every { random.nextLong() } returns 2L
repository.storePhoto(accessionId, photoData.inputStream(), metadata.copy(filename = "2.jpg"))
repository.storePhoto(accessionId, onePixelPng.inputStream(), metadata.copy(filename = "2.jpg"))

val photoRows = filesDao.findAll()
val fileIds = photoRows.mapNotNull { it.id }
Expand Down
Binary file added src/test/resources/file/pixel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/test/resources/file/sixPixels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 75d9856

Please sign in to comment.