Skip to content

Commit

Permalink
[SPARK-47691][SQL] Postgres: Support multi dimensional array on the w…
Browse files Browse the repository at this point in the history
…rite side

### What changes were proposed in this pull request?

This pull request adds support for writing Spark nested arrays to Postgres multi-dimensional arrays. The read side has not been implemented yet.

### Why are the changes needed?

To improve the Postgres JDBC data source.
### Does this PR introduce _any_ user-facing change?

Yes. Nested array types such as `array(array(...))` can now be written to Postgres; previously, `_LEGACY_ERROR_TEMP_2082` was raised.

### How was this patch tested?
new tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#45815 from yaooqinn/SPARK-47691.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
yaooqinn authored and dongjoon-hyun committed Apr 3, 2024
1 parent 55b5ff6 commit 49b7b6b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat
import java.time.LocalDateTime
import java.util.Properties

import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -493,4 +494,26 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
Array[Byte](48, 48, 48, 48, 49), Array[Byte](48, 48, 48, 49, 48)))
checkAnswer(df, expected)
}

test("SPARK-47691: multiple dimensional array") {
  // Reads `table` back over JDBC and expects collecting it to fail with a SparkException
  // that carries no error class.
  def assertReadFails(table: String): Unit = {
    checkError(
      exception = intercept[SparkException] {
        spark.read.jdbc(jdbcUrl, table, new Properties).collect()
      },
      errorClass = null)
  }

  // A one-dimensional array is written and read back unchanged.
  val singleDim = sql("select array(1, 2) as col0")
  singleDim.write.jdbc(jdbcUrl, "single_dim_array", new Properties)
  checkAnswer(spark.read.jdbc(jdbcUrl, "single_dim_array", new Properties), Row(Seq(1, 2)))

  // Two- and three-dimensional arrays can be written.
  val doubleDim = sql("select array(array(1, 2), array(3, 4)) as col0")
  doubleDim.write.jdbc(jdbcUrl, "double_dim_array", new Properties)

  val tripleDim = sql(
    "select array(array(array(1, 2), array(3, 4)), array(array(5, 6), array(7, 8))) as col0")
  tripleDim.write.jdbc(jdbcUrl, "triple_dim_array", new Properties)

  // Reading multi-dimensional array is not supported yet.
  assertReadFails("double_dim_array")
  assertReadFails("triple_dim_array")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.time.{Instant, LocalDate}
import java.util
import java.util.concurrent.TimeUnit

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try
Expand Down Expand Up @@ -668,20 +669,32 @@ object JdbcUtils extends Logging with SQLConfHelper {

case ArrayType(et, _) =>
// remove type length parameters from end of type name
val typeName = getJdbcType(et, dialect).databaseTypeDefinition.split("\\(")(0)
(stmt: PreparedStatement, row: Row, pos: Int) =>
et match {
case TimestampNTZType =>
val array = row.getSeq[java.time.LocalDateTime](pos)
val arrayType = conn.createArrayOf(
typeName,
getJdbcType(et, dialect).databaseTypeDefinition.split("\\(")(0),
array.map(dialect.convertTimestampNTZToJavaTimestamp).toArray)
stmt.setArray(pos + 1, arrayType)
case _ =>
val array = row.getSeq[AnyRef](pos)
// Resolves the database type name of the innermost element of a (possibly nested) array
// type. Array layers are peeled off one at a time; an atomic leaf yields its JDBC type
// definition with everything from the first "(" onwards dropped (i.e. without length
// parameters). Any other leaf type is rejected as an unsupported nested array.
@tailrec
def getElementTypeName(dt: DataType): String = dt match {
  case atomic: AtomicType =>
    val definition = getJdbcType(atomic, dialect).databaseTypeDefinition
    definition.split("\\(")(0)
  case ArrayType(inner, _) =>
    getElementTypeName(inner)
  case _ =>
    throw QueryExecutionErrors.nestedArraysUnsupportedError()
}

// Converts a (possibly nested) sequence into nested arrays, following the nesting of `dt`:
// when `dt` is an array type, every element of `seq` is treated as a sequence itself and
// converted recursively using the element type; otherwise `seq` is copied into a flat array.
def toArray(seq: scala.collection.Seq[Any], dt: DataType): Array[Any] = dt match {
  case ArrayType(childType, _) =>
    val converted = seq.map { child =>
      toArray(child.asInstanceOf[scala.collection.Seq[Any]], childType)
    }
    converted.toArray
  case _ =>
    seq.toArray
}

val seq = row.getSeq[AnyRef](pos)
val arrayType = conn.createArrayOf(
typeName,
array.toArray)
getElementTypeName(et),
toArray(seq, et).asInstanceOf[Array[AnyRef]])
stmt.setArray(pos + 1, arrayType)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ private case class PostgresDialect() extends JdbcDialect with SQLConfHelper {
case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT))
case t: DecimalType => Some(
JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
case ArrayType(et, _) if et.isInstanceOf[AtomicType] || et.isInstanceOf[ArrayType] =>
getJDBCType(et).map(_.databaseTypeDefinition)
.orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
.map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
Expand Down

0 comments on commit 49b7b6b

Please sign in to comment.