Skip to content

Commit

Permalink
Merge pull request #16 from EMH333/master
Browse files Browse the repository at this point in the history
Add support for type aliases
  • Loading branch information
pjagielski authored Sep 7, 2020
2 parents 40f9439 + 1115cd7 commit e92f196
Show file tree
Hide file tree
Showing 14 changed files with 208 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Maven:
* generates table mappings and functions for mapping from/to data classes
* type-safe SQL DSL without reading schema from existing database (code-first)
* explicit association fetching (via `leftJoin` / `innerJoin`)
* multiple data types support
* multiple data types support, including type aliases
* custom data type support (with `@Converter`), also for wrapped auto-generated ids
* you can still persist associations not directly reflected in domain model (eq. article favorites)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,15 @@ class ColumnProcessor(override val typeEnv: TypeEnvironment, private val annEnv:
}

private fun ImmutableKmType.toModelType(): Type? {
return (this.classifier as KmClassifier.Class).name
.split("/").let { Type(it.dropLast(1).joinToString(separator = "."), it.last()) }
return when(val classifier = (this.abbreviatedType?.classifier ?: this.classifier)){
is KmClassifier.Class -> {
classifier.name
.split("/").let { Type(it.dropLast(1).joinToString(separator = "."), it.last()) }
}
is KmClassifier.TypeAlias -> classifier.name
.split("/").let { Type(it.dropLast(1).joinToString(separator = "."), it.last(), this.copy(abbreviatedType = null).toModelType()) }
is KmClassifier.TypeParameter -> TODO()
}
}

private fun String.name(): Name = typeEnv.elementUtils.getName(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ enum class EnumType {

data class Type(
val packageName: String,
val simpleName: String
val simpleName: String,
val aliasOf: Type? = null
)

data class EmbeddableDefinition(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class EntityPropertyTypeValidator : Validator<EntityDefinition> {

override fun validate(el: EntityDefinition): ValidationResult {
val errors = mutableListOf<ValidationErrorMessage>()
el.properties.filter { !it.hasConverter() && !it.isEnumerated() && it.type !in supportedPropertyTypes }.forEach {
el.properties.filter { !it.hasConverter() && !it.isEnumerated() && it.type !in supportedPropertyTypes && it.type.aliasOf !in supportedPropertyTypes }.forEach {
errors.add(ValidationErrorMessage("Entity ${el.qualifiedName} has unsupported property type ${it.type}"))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CopiedReferencesMappingsGenerator : MappingsGenerator() {
val associations = entity.getAssociations(ONE_TO_ONE, ONE_TO_MANY, MANY_TO_MANY)
associations.forEach { assoc ->
val target = graphs[assoc.target.packageName]?.get(assoc.target) ?: throw EntityNotMappedException(assoc.target)
val entityIdTypeName = entityId.asTypeName()
val entityIdTypeName = entityId.asUnderlyingTypeName()
val associationMapName = "${entity.name.asVariable()}_${assoc.name}"
val associationMapValueType = if (assoc.type in listOf(ONE_TO_MANY, MANY_TO_MANY)) "MutableSet<${target.name}>" else "${target.name}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ abstract class MappingsGenerator : SourceGenerator {
}

private fun buildToEntityMapFunc(entityType: TypeElement, entity: EntityDefinition, graphs: EntityGraphs): FunSpec {
val rootKey = entity.id?.asTypeName() ?: throw MissingIdException(entity)
val rootKey = entity.id?.asUnderlyingTypeName() ?: throw MissingIdException(entity)

val rootVal = entity.name.asVariable()
val func = FunSpec.builder("to${entity.name}Map")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class RealReferencesMappingsGenerator : MappingsGenerator() {
val associations = entity.getAssociations(ONE_TO_ONE, ONE_TO_MANY, MANY_TO_MANY)
associations.forEach { assoc ->
val target = graphs[assoc.target.packageName]?.get(assoc.target) ?: throw EntityNotMappedException(assoc.target)
val entityIdTypeName = entityId.asTypeName()
val entityIdTypeName = entityId.asUnderlyingTypeName()
val associationMapName = "${entity.name.asVariable()}_${assoc.name}"
val associationMapValueType = if (assoc.type in listOf(ONE_TO_MANY, MANY_TO_MANY)) "MutableSet<${target.name}>" else "${target.name}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class TablesGenerator : SourceGenerator {
entity.getAssociations(AssociationType.MANY_TO_ONE).forEach { assoc ->
val name = assoc.name.toString()

val columnType = assoc.targetId.type.asClassName()
val columnType = assoc.targetId.type.asUnderlyingClassName()
CodeBlock.builder()
val initializer = associationInitializer(assoc, name)
tableSpec.addProperty(
Expand All @@ -84,7 +84,7 @@ class TablesGenerator : SourceGenerator {
entity.getAssociations(AssociationType.ONE_TO_ONE).filter {it.mapped}.forEach {assoc ->
val name = assoc.name.toString()

val columnType = assoc.targetId.type.asClassName()
val columnType = assoc.targetId.type.asUnderlyingClassName()
CodeBlock.builder()
val initializer = associationInitializer(assoc, name)
tableSpec.addProperty(
Expand All @@ -103,14 +103,14 @@ class TablesGenerator : SourceGenerator {
.superclass(Table::class)
.addSuperclassConstructorParameter(CodeBlock.of("%S", assoc.joinTable))

val sourceType = entity.id.type.asClassName()
val sourceType = entity.id.type.asUnderlyingClassName()
manyToManyTableSpec.addProperty(
PropertySpec.builder("${rootVal}SourceId", Column::class.asClassName().parameterizedBy(sourceType))
.initializer(manyToManyPropertyInitializer(entity.id, entity, "_source"))
.build()
)

val targetIdType = assoc.targetId.type.asClassName()
val targetIdType = assoc.targetId.type.asUnderlyingClassName()
val targetEntityDef = graphs.entity(assoc.target.packageName, assoc.target) ?:
throw AssociationTargetEntityNotFoundException(assoc.target)
manyToManyTableSpec.addProperty(
Expand Down Expand Up @@ -155,7 +155,7 @@ class TablesGenerator : SourceGenerator {
val isGenerated = entity.id?.generatedValue ?: false
val persistedName = if (isGenerated) "persisted${entityName.capitalize()}" else entityName
val func = FunSpec.builder("insert")
.receiver(Type(entityType.packageName, entity.tableName).asClassName())
.receiver(Type(entityType.packageName, entity.tableName).asUnderlyingClassName())
.addParameter(entity.name.asVariable(), entityType.asType().asTypeName())
.returns(entityType.asType().asTypeName())

Expand Down Expand Up @@ -195,7 +195,7 @@ class TablesGenerator : SourceGenerator {
}

private fun converterFunc(name: String, type: TypeName, it: ConverterDefinition, fileSpec: FileSpec.Builder) {
val wrapperName = when (it.targetType.asClassName()) {
val wrapperName = when (it.targetType.asUnderlyingClassName()) {
STRING -> "stringWrapper"
LONG -> "longWrapper"
else -> throw TypeConverterNotSupportedException(it.targetType)
Expand All @@ -215,7 +215,7 @@ class TablesGenerator : SourceGenerator {

val codeBlock = if (id.converter != null) {
converterPropInitializer(entityName = entity.name, propertyName = id.name, columnName = id.columnName.asVariable())
} else when (id.asTypeName()) {
} else when (id.asUnderlyingTypeName()) {
STRING -> CodeBlock.of("varchar(%S, %L)", id.columnName, id.annotation?.length ?: 255)
LONG -> CodeBlock.of("long(%S)", id.columnName)
INT -> CodeBlock.of("integer(%S)", id.columnName)
Expand Down Expand Up @@ -263,7 +263,7 @@ class TablesGenerator : SourceGenerator {

private fun enumPropInitializer(property: PropertyDefinition): CodeBlock {
val columnName = property.columnName
val enumType = property.type.asClassName()
val enumType = property.type.asUnderlyingClassName()

return when (property.enumerated!!.enumType) {
EnumType.STRING -> {
Expand All @@ -275,7 +275,7 @@ class TablesGenerator : SourceGenerator {
}

private fun typePropInitializer(property: PropertyDefinition): CodeBlock {
return when (property.asTypeName()) {
return when (property.asUnderlyingTypeName()) {
STRING -> CodeBlock.of("varchar(%S, %L)", property.columnName, property.annotation?.length ?: 255)
LONG -> CodeBlock.of("long(%S)", property.columnName)
BOOLEAN -> CodeBlock.of("bool(%S)", property.columnName)
Expand Down Expand Up @@ -318,7 +318,7 @@ class TablesGenerator : SourceGenerator {
private fun idCodeBlock(id: IdDefinition, entityName: Name, columnName: String): CodeBlock {
return if (id.converter != null) {
converterPropInitializer(entityName = entityName, propertyName = id.name, columnName = columnName)
} else when (id.asTypeName()) {
} else when (id.asUnderlyingTypeName()) {
STRING -> CodeBlock.of("varchar(%S, %L)", columnName, id.annotation?.length ?: 255)
LONG -> CodeBlock.of("long(%S)", columnName)
INT -> CodeBlock.of("integer(%S)", columnName)
Expand All @@ -329,6 +329,22 @@ class TablesGenerator : SourceGenerator {
}
}

fun IdDefinition.asUnderlyingTypeName(): TypeName {
return this.type.asUnderlyingClassName()
}

fun PropertyDefinition.asUnderlyingTypeName(): TypeName {
return this.type.asUnderlyingClassName()
}

fun Type.asUnderlyingClassName(): ClassName {
return if(this.aliasOf !=null) {
ClassName(this.aliasOf.packageName, this.aliasOf.simpleName)
}else{
ClassName(this.packageName, this.simpleName)
}
}

fun IdDefinition.asTypeName(): TypeName {
return this.type.asClassName()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,8 @@ package pl.touk.example
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.ZonedDateTime
import javax.persistence.Column
import javax.persistence.Embeddable
import javax.persistence.Embedded
import javax.persistence.Entity
import javax.persistence.*
import javax.persistence.EnumType.STRING
import javax.persistence.Enumerated
import javax.persistence.GeneratedValue
import javax.persistence.Id
import javax.persistence.JoinColumn
import javax.persistence.OneToOne
import javax.persistence.Table

@Entity
data class DefaultPropertyNameEntity(
Expand Down Expand Up @@ -48,6 +39,29 @@ data class NullablePropertyEntity(
val prop1: String?
)

typealias StringMap = Map<String, String>
typealias PlainString = String

@Entity
data class TypeAliasEntity(
@Id @GeneratedValue
val id: Long?,
@Convert(converter = StringMapConverter::class)
val aliased: StringMap,
val justAString: PlainString
)

@Converter
class StringMapConverter : AttributeConverter<StringMap, String> {
override fun convertToDatabaseColumn(attribute: StringMap?): String {
return attribute?.map { it.key + ":" + it.value }?.joinToString("\n") ?: ""
}

override fun convertToEntityAttribute(dbData: String?): StringMap {
return dbData?.splitToSequence("\n")?.associate { Pair(it.split(":")[0], it.split(":")[1]) } ?: HashMap()
}
}

@Entity
data class OneToOneSourceEntity(

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,21 @@ import javax.lang.model.util.Types
.containsKey(enumPropertyEntity(getTypeEnv()))
.containsValue(enumPropertyEntityDefinition(getTypeEnv()))
}

@Test
fun shouldHandleTypeAliases(){
//given
val typealiasGraphBuilder = typealiasGraphBuilder(getTypeEnv())

//when
val graphs = typealiasGraphBuilder.build()

//then
assertThat(graphs).containsKey("pl.touk.example")

assertThat(graphs["pl.touk.example"])
.containsKey(typealiasEntity(getTypeEnv()))
.containsValue(typealiasEntityDefinition(getTypeEnv()))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ interface EntityGraphSampleData {
return getTypeElement("pl.touk.example.PropertyTypeUnsupportedEntity", typeEnvironment.elementUtils)
}

fun typealiasEntity(typeEnvironment: TypeEnvironment): TypeElement {
return getTypeElement("pl.touk.example.TypeAliasEntity", typeEnvironment.elementUtils)
}

fun customerGraphBuilder(typeEnvironment: TypeEnvironment): EntityGraphBuilder {
val entity = customerTestEntity(typeEnvironment)
val id = getVariableElement(entity, typeEnvironment.elementUtils, "id")
Expand Down Expand Up @@ -451,6 +455,56 @@ interface EntityGraphSampleData {
)
}

fun typealiasGraphBuilder(typeEnvironment: TypeEnvironment): EntityGraphBuilder {
val elements = typeEnvironment.elementUtils

val entity = typealiasEntity(typeEnvironment)
val id = getVariableElement(entity, elements, "id")
val aliased = getVariableElement(entity, elements, "aliased")
val plainString = getVariableElement(entity, elements, "justAString")

val annEnv = AnnotationEnvironment(entities = listOf(entity), ids = listOf(id),
columns = listOf(aliased, plainString), oneToMany = emptyList(), manyToOne = emptyList(),
manyToMany = emptyList(), oneToOne = emptyList(), embedded = emptyList(), embeddedColumn = emptyList())

return EntityGraphBuilder(typeEnvironment, annEnv)
}

fun typealiasEntityDefinition(typeEnvironment: TypeEnvironment): EntityDefinition {
val entity = typealiasEntity(typeEnvironment)
val id = getVariableElement(entity, typeEnvironment.elementUtils, "id")
val prop1 = getVariableElement(entity, typeEnvironment.elementUtils, "aliased")
val plainString = getVariableElement(entity, typeEnvironment.elementUtils, "justAString")

return EntityDefinition(
name = entity.simpleName, qualifiedName = entity.qualifiedName,
table = entity.simpleName.asVariable(),
id = autoGenIdDefinition(id, typeEnvironment.elementUtils.getName(id.simpleName)),
properties = listOf(
propertyDefinition(
typeEnvironment,
prop1,
"aliased",
Type(
packageName = "pl.touk.example",
simpleName = "StringMap",
aliasOf = Type("kotlin.collections","Map")),
false)
.copy(converter = ConverterDefinition(
name = "pl.touk.example.StringMapConverter",
targetType = Type(packageName = "kotlin", simpleName = "String"
))),
propertyDefinition(
typeEnvironment,
plainString,
"justAString",
Type(
packageName = "pl.touk.example",
simpleName = "PlainString",
aliasOf = Type("kotlin", "String")),
false)))
}

private fun autoGenIdDefinition(id: VariableElement, name: Name): IdDefinition {
return IdDefinition(
name = id.simpleName,
Expand Down
26 changes: 26 additions & 0 deletions example/src/main/kotlin/pl/touk/krush/typealiases/VisitorLog.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package pl.touk.krush.typealiases

import javax.persistence.*

typealias VisitorList = List<String>
typealias PlainString = String

@Entity
data class VisitorLog(
@Id @GeneratedValue
val id: Long? = null,
@Convert(converter = VisitorListConverter::class)
val visitors: VisitorList,
val guard: PlainString
)

@Converter
class VisitorListConverter : AttributeConverter<VisitorList, String> {
override fun convertToDatabaseColumn(attribute: VisitorList?): String {
return attribute?.joinToString(",") ?: ""
}

override fun convertToEntityAttribute(dbData: String?): VisitorList {
return dbData?.split(",") ?: emptyList()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package pl.touk.krush.typealiases

import org.assertj.core.api.Assertions
import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.select
import org.jetbrains.exposed.sql.transactions.transaction
import org.junit.jupiter.api.Test
import pl.touk.krush.base.BaseDatabaseTest
import pl.touk.krush.types.Event
import pl.touk.krush.types.EventTable
import pl.touk.krush.types.insert
import pl.touk.krush.types.toEventList
import java.time.*
import java.util.*

class VisitorLogTest : BaseDatabaseTest() {

@Test
fun shouldHandleTypeAliases() {
transaction {
SchemaUtils.create(VisitorLogTable)

// given
val log = VisitorLogTable.insert(VisitorLog(visitors = listOf("Krush", "Kotlin", "Gradle" ), guard = "Kelly"))

//when
val logs = (VisitorLogTable)
.select { VisitorLogTable.guard eq "Kelly" }
.toVisitorLogList()

//then
Assertions.assertThat(logs).containsOnly(log)
}
}
}
Loading

0 comments on commit e92f196

Please sign in to comment.