Feat: (WIP) Stdlib functions #102

Open. Wants to merge 10 commits into base: main.
README.md: 3 changes (3 additions, 0 deletions)
@@ -192,6 +192,9 @@ to create `TypedColumn`s and with those a new Dataset from pieces of another using
```kotlin
val dataset: Dataset<YourClass> = ...
val newDataset: Dataset<Pair<TypeA, TypeB>> = dataset.selectTyped(col(YourClass::colA), col(YourClass::colB))

// Alternatively, for instance when working with a Dataset<Row>
val typedDataset: Dataset<Pair<String, Int>> = otherDataset.selectTyped(col("a").`as`<String>(), col("b").`as`<Int>())
```
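
As a further hedged illustration (the column name `"c"` here is hypothetical), three or more columns yield a `Triple` or `Arity4`/`Arity5` instead of a `Pair`, and property-based and `as`-typed columns can be mixed freely:

```kotlin
val tripleDataset: Dataset<Triple<TypeA, TypeB, String>> =
    dataset.selectTyped(col(YourClass::colA), col(YourClass::colB), col("c").`as`<String>())
```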

### Overload resolution ambiguity

@@ -34,6 +34,7 @@ object KSparkExtensions {

def collectAsList[T](ds: Dataset[T]): util.List[T] = JavaConverters.seqAsJavaList(ds.collect())

def tailAsList[T](ds: Dataset[T], n: Int): util.List[T] = util.Arrays.asList(ds.tail(n) : _*)

def debugCodegen(df: Dataset[_]): Unit = {
import org.apache.spark.sql.execution.debug._
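
A minimal Kotlin-side sketch of how the new `tailAsList` helper could be consumed, assuming the Scala object is reachable from Kotlin through its generated static forwarders; the import's package and the `Person`/`lastRows` names are illustrative rather than part of this PR:

```kotlin
import org.apache.spark.sql.Dataset
import org.jetbrains.spark.api.KSparkExtensions // package name is an assumption

// Hypothetical data class, for illustration only.
data class Person(val name: String, val age: Int)

// Returns the last `n` rows of the Dataset as a Kotlin List by delegating to the
// Scala helper added above (which wraps ds.tail(n) in a java.util.List).
fun lastRows(ds: Dataset<Person>, n: Int): List<Person> =
    KSparkExtensions.tailAsList(ds, n)
```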

@@ -647,12 +647,19 @@ operator fun Column.get(key: Any): Column = getItem(key)
fun lit(a: Any) = functions.lit(a)

/**
* Provides a type hint about the expected return value of this column. This information can
* be used by operations such as `select` on a [Dataset] to automatically convert the
* results into the correct JVM types.
*
* ```
* val df: Dataset<Row> = ...
* val typedDataset: Dataset<Int> = df.selectTyped(col("a").`as`<Int>())
* ```
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T> Column.`as`(): TypedColumn<Any, T> = `as`(encoder<T>())
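
// Hedged illustration (not part of this diff): because the overload above delegates to
// `as`(encoder<T>()), the two forms below build the same TypedColumn; the column
// name "a" is hypothetical.
//
//     col("a").`as`<Int>()            // reified shorthand
//     col("a").`as`(encoder<Int>())   // explicit-encoder form it expands to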


/**
* Alias for [Dataset.joinWith] which passes "left" argument
* and respects the fact that in result of left join right relation is nullable

@@ -809,45 +816,74 @@ fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply {
/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1> Dataset<T>.selectTyped(
c1: TypedColumn<out Any, U1>,
): Dataset<U1> = select(c1 as TypedColumn<T, U1>)

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
): Dataset<Pair<U1, U2>> =
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
).map { Pair(it._1(), it._2()) }

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
): Dataset<Triple<U1, U2, U3>> =
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
).map { Triple(it._1(), it._2(), it._3()) }

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
c4: TypedColumn<out Any, U4>,
): Dataset<Arity4<U1, U2, U3, U4>> =
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
c4 as TypedColumn<T, U4>,
).map { Arity4(it._1(), it._2(), it._3(), it._4()) }

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
c4: TypedColumn<out Any, U4>,
c5: TypedColumn<out Any, U5>,
): Dataset<Arity5<U1, U2, U3, U4, U5>> =
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
c4 as TypedColumn<T, U4>,
c5 as TypedColumn<T, U5>,
).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
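
// Hedged usage sketch (illustration only, not part of this diff): widening the parameters
// to TypedColumn<out Any, U> lets property-based columns (typed on the data class) and
// name-based columns (typed on Any via `as`) be mixed in one selectTyped call.
// Assuming a hypothetical data class Sale(val item: String, val amount: Int):
//
//     val pairs: Dataset<Pair<Int, String>> = sales.selectTyped(
//         col(Sale::amount),            // TypedColumn<Sale, Int>
//         col("item").`as`<String>(),   // TypedColumn<Any, String>
//     )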

@OptIn(ExperimentalStdlibApi::class)
inline fun <reified T> schema(map: Map<String, KType> = mapOf()) = schema(typeOf<T>(), map)

@@ -339,31 +339,34 @@ class ApiTest : ShouldSpec({
SomeClass(intArrayOf(1, 2, 4), 5),
)

val newDS1WithAs: Dataset<Int> = dataset.selectTyped(
col("b").`as`<Int>(),
)
newDS1WithAs.show()

val newDS2: Dataset<Pair<Int, Int>> = dataset.selectTyped(
// col(SomeClass::a), NOTE: this doesn't work on Spark 2.4, returning a data class with an array in it
col(SomeClass::b),
col(SomeClass::b),
)
newDS2.show()

val newDS3: Dataset<Triple<Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
)
newDS3.show()

val newDS4: Dataset<Arity4<Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
)
newDS4.show()

val newDS5: Dataset<Arity5<Int, Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),