diff --git a/core/workflow-core/src/main/scala/edu/uci/ics/amber/core/storage/result/ArrowFileDocument.scala b/core/workflow-core/src/main/scala/edu/uci/ics/amber/core/storage/result/ArrowFileDocument.scala index dc33304bbc9..afbe38b76cc 100644 --- a/core/workflow-core/src/main/scala/edu/uci/ics/amber/core/storage/result/ArrowFileDocument.scala +++ b/core/workflow-core/src/main/scala/edu/uci/ics/amber/core/storage/result/ArrowFileDocument.scala @@ -16,18 +16,22 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Using class ArrowFileDocument[T]( - val uri: URI, - val arrowSchema: Schema, - val serializer: (T, Int, VectorSchemaRoot) => Unit, - val deserializer: (Int, VectorSchemaRoot) => T - ) extends VirtualDocument[T] with BufferedItemWriter[T] { + val uri: URI, + val arrowSchema: Schema, + val serializer: (T, Int, VectorSchemaRoot) => Unit, + val deserializer: (Int, VectorSchemaRoot) => T +) extends VirtualDocument[T] + with BufferedItemWriter[T] { private val file: FileObject = VFS.getManager.resolveFile(uri) private val lock = new ReentrantReadWriteLock() - private val allocator = new RootAllocator() private val buffer = new ArrayBuffer[T]() override val bufferSize: Int = 1024 + private var arrowRootallocator: RootAllocator = _ + private var arrowVectorSchemaRoot: VectorSchemaRoot = _ + private var arrowFileWriter: ArrowFileWriter = _ + // Initialize the file if it doesn't exist withWriteLock { if (!file.exists()) { @@ -39,127 +43,105 @@ class ArrowFileDocument[T]( } } - // Utility function to wrap code block with read lock private def withReadLock[M](block: => M): M = { lock.readLock().lock() try block finally lock.readLock().unlock() } - // Utility function to wrap code block with write lock private def withWriteLock[M](block: => M): M = { lock.writeLock().lock() try block finally lock.writeLock().unlock() } - override def putOne(item: T): Unit = withWriteLock { - buffer.append(item) - if (buffer.size >= bufferSize) { - flushBuffer() + override def open(): Unit = + withWriteLock { + buffer.clear() + arrowRootallocator = new RootAllocator() + arrowVectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, arrowRootallocator) + val outputStream = new FileOutputStream(file.getURL.getPath) + arrowFileWriter = new ArrowFileWriter(arrowVectorSchemaRoot, null, outputStream.getChannel) + arrowFileWriter.start() } - } - - override def removeOne(item: T): Unit = withWriteLock { - buffer -= item - } - - /** Write buffered items to the file and clear the buffer */ - private def flushBuffer(): Unit = withWriteLock { - val outputStream = new FileOutputStream(file.getURL.getPath, true) - Using.Manager { use => - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val writer = new ArrowFileWriter(root, null, outputStream.getChannel) - use(writer) - use(root) - writer.start() - - buffer.zipWithIndex.foreach { case (item, index) => - serializer(item, index, root) + override def putOne(item: T): Unit = + withWriteLock { + buffer.append(item) + if (buffer.size >= bufferSize) { + flushBuffer() } - - root.setRowCount(buffer.size) - writer.writeBatch() - buffer.clear() - writer.end() } - } - - /** Open the writer (clear the buffer) */ - override def open(): Unit = withWriteLock { - buffer.clear() - } - /** Close the writer, flushing any remaining buffered items */ - override def close(): Unit = withWriteLock { - if (buffer.nonEmpty) { - flushBuffer() + override def removeOne(item: T): Unit = + withWriteLock { + buffer -= item } - allocator.close() - } - /** Get an iterator of data items of type T */ - private def getIterator: Iterator[T] = withReadLock { - val path = Paths.get(file.getURL.toURI) - val channel: SeekableByteChannel = FileChannel.open(path, StandardOpenOption.READ) - val reader = new ArrowFileReader(channel, allocator) - val root = reader.getVectorSchemaRoot - - new Iterator[T] { - private var currentIndex = 0 - private var currentBatchLoaded = reader.loadNextBatch() - - private def loadNextBatch(): Boolean = { - currentBatchLoaded = reader.loadNextBatch() - currentIndex = 0 - currentBatchLoaded + private def flushBuffer(): Unit = + withWriteLock { + if (buffer.nonEmpty) { + buffer.zipWithIndex.foreach { + case (item, index) => + serializer(item, index, arrowVectorSchemaRoot) + } + arrowVectorSchemaRoot.setRowCount(buffer.size) + arrowFileWriter.writeBatch() + buffer.clear() + arrowVectorSchemaRoot.clear() } + } - override def hasNext: Boolean = currentIndex < root.getRowCount || loadNextBatch() - - override def next(): T = { - if (!hasNext) throw new NoSuchElementException("No more elements") - val item = deserializer(currentIndex, root) - currentIndex += 1 - item + override def close(): Unit = + withWriteLock { + if (buffer.nonEmpty) { + flushBuffer() + } + if (arrowFileWriter != null) { + arrowFileWriter.end() + arrowFileWriter.close() } + if (arrowVectorSchemaRoot != null) arrowVectorSchemaRoot.close() + if (arrowRootallocator != null) arrowRootallocator.close() } - } - - /** Get the ith data item */ - override def getItem(i: Int): T = withReadLock { - getIterator.drop(i).next() - } - - /** Get a range of data items */ - override def getRange(from: Int, until: Int): Iterator[T] = withReadLock { - getIterator.slice(from, until) - } - - /** Get items after a certain offset */ - override def getAfter(offset: Int): Iterator[T] = withReadLock { - getIterator.drop(offset + 1) - } - - /** Get the total count of items */ - override def getCount: Long = withReadLock { - getIterator.size - } - - /** Get all items as an iterator */ - override def get(): Iterator[T] = withReadLock { - getIterator - } - /** Physically remove the file */ - override def clear(): Unit = withWriteLock { - if (file.exists()) { - file.delete() - } else { - throw new RuntimeException(s"File $uri doesn't exist") + override def get(): Iterator[T] = + withReadLock { + val path = Paths.get(file.getURL.toURI) + val allocator = new RootAllocator() + val channel: SeekableByteChannel = FileChannel.open(path, StandardOpenOption.READ) + val reader = new ArrowFileReader(channel, allocator) + val root = reader.getVectorSchemaRoot + + new Iterator[T] { + private var currentIndex = 0 + private var currentBatchLoaded = reader.loadNextBatch() + + private def loadNextBatch(): Boolean = { + currentBatchLoaded = reader.loadNextBatch() + currentIndex = 0 + currentBatchLoaded + } + + override def hasNext: Boolean = currentIndex < root.getRowCount || loadNextBatch() + + override def next(): T = { + if (!hasNext) throw new NoSuchElementException("No more elements") + val item = deserializer(currentIndex, root) + currentIndex += 1 + item + } + } } - } override def getURI: URI = uri -} \ No newline at end of file + + override def clear(): Unit = + withWriteLock { + if (file.exists()) { + file.delete() + } else { + throw new RuntimeException(s"File $uri doesn't exist") + } + } +} diff --git a/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/ArrowFileDocumentSpec.scala b/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/ArrowFileDocumentSpec.scala index 23693063922..1b0794025a3 100644 --- a/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/ArrowFileDocumentSpec.scala +++ b/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/ArrowFileDocumentSpec.scala @@ -1,8 +1,8 @@ package edu.uci.ics.amber.storage.result import edu.uci.ics.amber.core.storage.result.ArrowFileDocument -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.arrow.vector.{IntVector, VarCharVector, VectorSchemaRoot} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.arrow.vector.{VarCharVector, VectorSchemaRoot} import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.ScalaFutures.convertScalaFuture import org.scalatest.flatspec.AnyFlatSpec @@ -27,22 +27,29 @@ object ArrowFileDocumentSpec { class ArrowFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeAndAfter { - val stringArrowSchema = new Schema(List( - Field.nullablePrimitive("data", ArrowType.Utf8.INSTANCE) - ).asJava) - - def stringSerializer(item: String, index: Int, root: VectorSchemaRoot): Unit = { - val vector = root.getVector("data").asInstanceOf[VarCharVector] - vector.setSafe(index, item.getBytes("UTF-8")) - } - - def stringDeserializer(index: Int, root: VectorSchemaRoot): String = { - new String(root.getVector("data").asInstanceOf[VarCharVector].get(index)) - } + val stringArrowSchema = new Schema( + List( + Field.nullablePrimitive("data", ArrowType.Utf8.INSTANCE) + ).asJava + ) def createDocument(): ArrowFileDocument[String] = { val tempPath = Files.createTempFile("arrow_test", ".arrow") - new ArrowFileDocument[String](tempPath.toUri, stringArrowSchema, stringSerializer, stringDeserializer) + new ArrowFileDocument[String]( + tempPath.toUri, + stringArrowSchema, + ArrowFileDocumentSpec.stringSerializer, + ArrowFileDocumentSpec.stringDeserializer + ) + } + + def openDocument(uri: URI): ArrowFileDocument[String] = { + new ArrowFileDocument[String]( + uri, + stringArrowSchema, + ArrowFileDocumentSpec.stringSerializer, + ArrowFileDocumentSpec.stringDeserializer + ) } def deleteDocument(doc: ArrowFileDocument[String]): Unit = { @@ -57,8 +64,7 @@ class ArrowFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeAndAfte doc.close() val items = doc.get().toList - items should contain("Buffered Item 1") - items should contain("Buffered Item 2") + items should contain theSameElementsAs List("Buffered Item 1", "Buffered Item 2") deleteDocument(doc) } @@ -72,24 +78,26 @@ class ArrowFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeAndAfte doc.close() val items = doc.get().toList - largeBuffer.foreach { item => - items should contain(item) - } + items should contain theSameElementsAs largeBuffer deleteDocument(doc) } - it should "allow removing items from the buffer" in { + it should "override file content when reopened for writing" in { val doc = createDocument() + + // First write doc.open() - doc.putOne("Item to keep") - doc.putOne("Item to remove") - doc.removeOne("Item to remove") + doc.putOne("First Write") + doc.close() + + // Second write should override the first one + doc.open() + doc.putOne("Second Write") doc.close() val items = doc.get().toList - items should contain("Item to keep") - items should not contain "Item to remove" + items should contain only "Second Write" deleteDocument(doc) } @@ -98,15 +106,15 @@ class ArrowFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeAndAfte val doc = createDocument() val numberOfThreads = 5 + doc.open() val futures = (1 to numberOfThreads).map { i => Future { - doc.open() doc.putOne(s"Content from thread $i") - doc.close() } } Future.sequence(futures).futureValue + doc.close() val items = doc.get().toList (1 to numberOfThreads).foreach { i => @@ -116,23 +124,33 @@ class ArrowFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeAndAfte deleteDocument(doc) } - it should "handle writing after reopening the file" in { + it should "handle concurrent reads and writes safely" in { val doc = createDocument() - // First write + // Writer thread to add items doc.open() - doc.putOne("First Write") - doc.close() + val writerFuture = Future { + (1 to 10).foreach { i => + doc.putOne(s"Write $i") + } + } - // Second write - doc.open() - doc.putOne("Second Write") + // Reader threads to read items concurrently + val readerFutures = (1 to 3).map { _ => + Future { + doc.get().toList + } + } + + Future.sequence(readerFutures).futureValue + writerFuture.futureValue doc.close() - val items = doc.get().toList - items should contain("First Write") - items should contain("Second Write") + val finalItems = doc.get().toList + (1 to 10).foreach { i => + finalItems should contain(s"Write $i") + } deleteDocument(doc) } -} \ No newline at end of file +} diff --git a/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/PartitionedFileDocumentSpec.scala b/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/PartitionedFileDocumentSpec.scala index 4a3840fbe76..1260eb3f4e0 100644 --- a/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/PartitionedFileDocumentSpec.scala +++ b/core/workflow-core/src/test/scala/edu/uci/ics/amber/storage/result/PartitionedFileDocumentSpec.scala @@ -1,6 +1,9 @@ package edu.uci.ics.amber.core.storage.result -import edu.uci.ics.amber.core.storage.result.ArrowFileDocumentSpec.{stringDeserializer, stringSerializer} +import edu.uci.ics.amber.core.storage.result.ArrowFileDocumentSpec.{ + stringDeserializer, + stringSerializer +} import edu.uci.ics.amber.core.storage.result.PartitionedFileDocument.getPartitionURI import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.arrow.vector.{VarCharVector, VectorSchemaRoot} @@ -23,12 +26,13 @@ object ArrowFileDocumentSpec { } } - class PartitionedFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeAndAfter { - val stringArrowSchema = new Schema(List( - Field.nullablePrimitive("data", ArrowType.Utf8.INSTANCE) - ).asJava) + val stringArrowSchema = new Schema( + List( + Field.nullablePrimitive("data", ArrowType.Utf8.INSTANCE) + ).asJava + ) var partitionDocument: PartitionedFileDocument[ArrowFileDocument[String], String] = _ val numOfPartitions = 3 @@ -39,7 +43,8 @@ class PartitionedFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeA partitionDocument = new PartitionedFileDocument[ArrowFileDocument[String], String]( partitionId, numOfPartitions, - uri => new ArrowFileDocument[String](uri, stringArrowSchema, stringSerializer, stringDeserializer) + uri => + new ArrowFileDocument[String](uri, stringArrowSchema, stringSerializer, stringDeserializer) ) } @@ -51,7 +56,12 @@ class PartitionedFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeA "PartitionDocument" should "create and write to each partition directly" in { for (i <- 0 until numOfPartitions) { val partitionURI = getPartitionURI(partitionId, i) - val fileDoc = new ArrowFileDocument[String](partitionURI, stringArrowSchema, stringSerializer, stringDeserializer) + val fileDoc = new ArrowFileDocument[String]( + partitionURI, + stringArrowSchema, + stringSerializer, + stringDeserializer + ) fileDoc.open() fileDoc.putOne(s"Data for partition $i") fileDoc.close() @@ -67,7 +77,12 @@ class PartitionedFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeA // Write some data directly to each partition for (i <- 0 until numOfPartitions) { val partitionURI = getPartitionURI(partitionId, i) - val fileDoc = new ArrowFileDocument[String](partitionURI, stringArrowSchema, stringSerializer, stringDeserializer) + val fileDoc = new ArrowFileDocument[String]( + partitionURI, + stringArrowSchema, + stringSerializer, + stringDeserializer + ) fileDoc.open() fileDoc.putOne(s"Content in partition $i") fileDoc.close() @@ -84,7 +99,12 @@ class PartitionedFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeA // Write some data directly to each partition for (i <- 0 until numOfPartitions) { val partitionURI = getPartitionURI(partitionId, i) - val fileDoc = new ArrowFileDocument[String](partitionURI, stringArrowSchema, stringSerializer, stringDeserializer) + val fileDoc = new ArrowFileDocument[String]( + partitionURI, + stringArrowSchema, + stringSerializer, + stringDeserializer + ) fileDoc.open() fileDoc.putOne(s"Some data in partition $i") fileDoc.close() @@ -107,7 +127,12 @@ class PartitionedFileDocumentSpec extends AnyFlatSpec with Matchers with BeforeA val futures = (0 until numOfPartitions).map { i => Future { val partitionURI = getPartitionURI(partitionId, i) - val fileDoc = new ArrowFileDocument[String](partitionURI, stringArrowSchema, stringSerializer, stringDeserializer) + val fileDoc = new ArrowFileDocument[String]( + partitionURI, + stringArrowSchema, + stringSerializer, + stringDeserializer + ) fileDoc.open() fileDoc.putOne(s"Concurrent write to partition $i") fileDoc.close()