Skip to content

Commit

Permalink
update arrow stream writer (#2388)
Browse files Browse the repository at this point in the history
  • Loading branch information
AFine-gs authored Oct 19, 2023
1 parent bdd95cf commit 0ca8d0a
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

import java.nio.charset.Charset;
import java.util.Calendar;
import java.util.HashMap;
import java.util.Locale;
import java.util.GregorianCalendar;
import java.util.TimeZone;

import org.apache.arrow.adapter.jdbc.JdbcFieldInfo;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.adapter.jdbc.LegendArrowVectorIterator;
Expand All @@ -32,39 +30,40 @@

import java.io.IOException;
import java.io.OutputStream;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.finos.legend.engine.plan.execution.stores.relational.result.RelationalResult;

public class ArrowDataWriter extends ExternalFormatWriter implements AutoCloseable
{
private final LegendArrowVectorIterator iterator;
private final BufferAllocator allocator;

public ArrowDataWriter(ResultSet resultSet) throws SQLException
public ArrowDataWriter(RelationalResult resultSet) throws SQLException
{

HashMap<Integer, JdbcFieldInfo> map = new HashMap<Integer, JdbcFieldInfo>();

this.allocator = new RootAllocator();
JdbcToArrowConfig config = new JdbcToArrowConfigBuilder(allocator, Calendar.getInstance(TimeZone.getDefault(), Locale.ROOT)).build();
this.iterator = LegendArrowVectorIterator.create(resultSet, config);
Calendar calendar = resultSet.getRelationalDatabaseTimeZone() == null ?
new GregorianCalendar(TimeZone.getTimeZone("GMT")) :
new GregorianCalendar(TimeZone.getTimeZone(resultSet.getRelationalDatabaseTimeZone()));
JdbcToArrowConfig config = new JdbcToArrowConfigBuilder(allocator, calendar).setReuseVectorSchemaRoot(true).build();
this.iterator = LegendArrowVectorIterator.create(resultSet.getResultSet(), config);

}

@Override
public void writeData(OutputStream outputStream) throws IOException
{
try
try (VectorSchemaRoot vector = iterator.next();
ArrowStreamWriter writer = new ArrowStreamWriter(vector, null, outputStream);
)
{
writer.start();
writer.writeBatch();
while (this.iterator.hasNext())
{
try (VectorSchemaRoot vector = iterator.next();
ArrowStreamWriter writer = new ArrowStreamWriter(vector, null, outputStream)
)
{
writer.start();
writer.writeBatch();
}
iterator.next();
writer.writeBatch();

}
}
catch (Exception e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public Result executeExternalizeTDSExecutionNode(ExternalFormatExternalizeTDSExe
private Result streamArrowFromRelational(RelationalResult relationalResult) throws SQLException, IOException
{

return new ExternalFormatSerializeResult(new ArrowDataWriter(relationalResult.getResultSet()), relationalResult, CONTENT_TYPE);
return new ExternalFormatSerializeResult(new ArrowDataWriter(relationalResult), relationalResult, CONTENT_TYPE);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.io.FileOutputStream;
import java.io.IOException;
import java.util.TimeZone;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
Expand All @@ -31,6 +32,7 @@
import org.finos.legend.engine.protocol.pure.v1.model.packageableElement.store.relational.model.result.SQLResultColumn;
import org.finos.legend.engine.shared.core.api.request.RequestContext;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

Expand All @@ -45,7 +47,6 @@ public class TestArrowNodeExecutor

{


@Test
public void testExternalize() throws Exception
{
Expand All @@ -57,7 +58,7 @@ public void testExternalize() throws Exception

mockExecutionNode.connection = mockDatabaseConnection;
Mockito.when(mockDatabaseConnection.accept(any())).thenReturn(false);
try (Connection conn = DriverManager.getConnection("jdbc:h2:~/test", "sa", "");
try (Connection conn = DriverManager.getConnection("jdbc:h2:~/test;TIME ZONE=America/New_York", "sa", "");
ByteArrayOutputStream outputStream = new ByteArrayOutputStream())
{
//setup table
Expand All @@ -70,7 +71,7 @@ public void testExternalize() throws Exception
conn.createStatement().execute("INSERT INTO testtable (testInt, testString, testDate, testBool) VALUES(1,'A', '2020-01-01 00:00:00-05:00',true),( 2,null, '2020-01-01 00:00:00-02:00',false ),( 3,'B', '2020-01-01 00:00:00-05:00',false )");
conn.createStatement().execute("INSERT INTO testtableJoin (testIntR, testStringR) VALUES(6,'A'), (1,'B')");

RelationalResult result = new RelationalResult(FastList.newListWith(new RelationalExecutionActivity("SELECT * FROM testtable left join testtableJoin on testtable.testInt=testtableJoin.testIntR", null)), mockExecutionNode, FastList.newListWith(new SQLResultColumn("testInt", "INTEGER"), new SQLResultColumn("testStringR", "VARCHAR"), new SQLResultColumn("testString", "VARCHAR"), new SQLResultColumn("testDate", "TIMESTAMP"), new SQLResultColumn("testBool", "TIMESTAMP")), null, "GMT", conn, null, null, null, new RequestContext());
RelationalResult result = new RelationalResult(FastList.newListWith(new RelationalExecutionActivity("SELECT * FROM testtable left join testtableJoin on testtable.testInt=testtableJoin.testIntR", null)), mockExecutionNode, FastList.newListWith(new SQLResultColumn("testInt", "INTEGER"), new SQLResultColumn("testStringR", "VARCHAR"), new SQLResultColumn("testString", "VARCHAR"), new SQLResultColumn("testDate", "TIMESTAMP"), new SQLResultColumn("testBool", "TIMESTAMP")), null, "America/New_York", conn, null, null, null, new RequestContext());

ExternalFormatSerializeResult nodeExecute = (ExternalFormatSerializeResult) extension.executeExternalizeTDSExecutionNode(node, result, null, null);

Expand All @@ -97,7 +98,7 @@ public void testExternalizeAsString() throws Exception

mockExecutionNode.connection = mockDatabaseConnection;
Mockito.when(mockDatabaseConnection.accept(any())).thenReturn(false);
try (Connection conn = DriverManager.getConnection("jdbc:h2:~/test", "sa", "");
try (Connection conn = DriverManager.getConnection("jdbc:h2:~/test;TIME ZONE=America/New_York", "sa", "");
)

{
Expand All @@ -106,7 +107,7 @@ public void testExternalizeAsString() throws Exception
conn.createStatement().execute("Create Table testtable (testInt INTEGER, testString VARCHAR(255), testDate TIMESTAMP, testBool BOOLEAN)");
conn.createStatement().execute("INSERT INTO testtable (testInt, testString, testDate, testBool) VALUES(1,'A', '2020-01-01 00:00:00-05:00',true),( 2,'B', '2020-01-01 00:00:00-02:00',false ),( 3,'B', '2020-01-01 00:00:00-05:00',false )");

RelationalResult result = new RelationalResult(FastList.newListWith(new RelationalExecutionActivity("SELECT * FROM testtable", null)), mockExecutionNode, FastList.newListWith(new SQLResultColumn("testInt", "INTEGER"), new SQLResultColumn("testString", "VARCHAR"), new SQLResultColumn("testDate", "TIMESTAMP"), new SQLResultColumn("testBool", "TIMESTAMP")), null, "GMT", conn, null, null, null, new RequestContext());
RelationalResult result = new RelationalResult(FastList.newListWith(new RelationalExecutionActivity("SELECT * FROM testtable", null)), mockExecutionNode, FastList.newListWith(new SQLResultColumn("testInt", "INTEGER"), new SQLResultColumn("testString", "VARCHAR"), new SQLResultColumn("testDate", "TIMESTAMP"), new SQLResultColumn("testBool", "TIMESTAMP")), null, "America/New_York", conn, null, null, null, new RequestContext());

ExternalFormatSerializeResult nodeExecute = (ExternalFormatSerializeResult) extension.executeExternalizeTDSExecutionNode(node, result, null, null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@
public class TestArrowQueries
{


@Test
public void runTest()
{
try (
ByteArrayOutputStream baos = new ByteArrayOutputStream();
)
{

ObjectMapper objectMapper = ObjectMapperFactory.getNewStandardObjectMapperWithPureProtocolExtensionSupports();
ExecuteInput input = objectMapper.readValue(getClass().getClassLoader().getResource("arrowService.json"), ExecuteInput.class);

Expand All @@ -81,7 +83,8 @@ public void runTest()
.build();
StreamingResult streamingResult = (StreamingResult) executor.executeWithArgs(executeArgs);
streamingResult.stream(baos, SerializationFormat.DEFAULT);
assertAndValidateArrow(new ByteArrayInputStream(baos.toByteArray()), "expectedArrowServiceData.arrow");
assertAndValidateArrow(new ByteArrayInputStream(baos.toByteArray()), "expectedArrowServiceData.arrow");

}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
{
"_type": "collection",
"multiplicity": {
"lowerBound": 5,
"upperBound": 5
"lowerBound": 4,
"upperBound": 4
},
"values": [
{
Expand Down Expand Up @@ -91,27 +91,6 @@
}
]
},
{
"_type": "lambda",
"body": [
{
"_type": "property",
"parameters": [
{
"_type": "var",
"name": "x"
}
],
"property": "settlementDateTime"
}
],
"parameters": [
{
"_type": "var",
"name": "x"
}
]
},
{
"_type": "lambda",
"body": [
Expand All @@ -138,8 +117,8 @@
{
"_type": "collection",
"multiplicity": {
"lowerBound": 5,
"upperBound": 5
"lowerBound": 4,
"upperBound": 4
},
"values": [
{
Expand All @@ -154,10 +133,6 @@
"_type": "string",
"value": "Quantity"
},
{
"_type": "string",
"value": "Settlement Date Time"
},
{
"_type": "string",
"value": "Trade Date"
Expand Down Expand Up @@ -447,6 +422,7 @@
{
"connection": {
"_type": "RelationalDatabaseConnection",
"timeZone" : "America/New_York",
"authenticationStrategy": {
"_type": "h2Default"
},
Expand Down
Binary file not shown.

0 comments on commit 0ca8d0a

Please sign in to comment.