Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pretty print SQL #78

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions sqlest/src/main/scala/sqlest/executor/Database.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Session(database: Database) extends Logging {

def executeSelect[A](select: Select[_, _])(extractor: ResultSet => A): A =
withConnection { connection =>
val (preprocessedSelect, sql, argumentLists) = database.statementBuilder(select)
val (preprocessedSelect, sql, argumentLists, prettySql) = database.statementBuilder(select)
try {
val startTime = new DateTime
val preparedStatement = prepareStatement(connection, preprocessedSelect, sql, argumentLists)
Expand All @@ -92,7 +92,7 @@ class Session(database: Database) extends Logging {
try {
val result = extractor(resultSet)
val endTime = new DateTime
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, sql, argumentLists)}")
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, prettySql, argumentLists)}")
result
} finally {
try {
Expand Down Expand Up @@ -177,11 +177,16 @@ class Session(database: Database) extends Logging {

protected def logDetails(connection: Connection, sql: String, argumentLists: List[List[LiteralColumn[_]]]) = {
val connectionLog = database.connectionDescription.map(connectionDescription => s", connection [${connectionDescription(connection)}]").getOrElse("")

val argumentsLog =
if (argumentLists.size == 1) argumentLists.head.map(_.value).mkString(", ")
else argumentLists.map(_.map(_.value).mkString("(", ", ", ")")).mkString(", ")

s"sql [$sql], arguments [$argumentsLog]${connectionLog}"
s"""sql [
|
|$sql
|
|], arguments [$argumentsLog]${connectionLog}""".stripMargin
}
}

Expand Down Expand Up @@ -265,14 +270,14 @@ case class Transaction(database: Database) extends Session(database) {

def executeCommand(command: Command): Int =
withConnection { connection =>
val (preprocessedCommand, sql, argumentLists) = database.statementBuilder(command)
val (preprocessedCommand, sql, argumentLists, prettySql) = database.statementBuilder(command)
val startTime = new DateTime
try {
val preparedStatement = prepareStatement(connection, preprocessedCommand, sql, argumentLists)
try {
val result = preparedStatement.executeBatch.sum
val endTime = new DateTime
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, sql, argumentLists)}")
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, prettySql, argumentLists)}")
result
} finally {
try {
Expand All @@ -288,7 +293,7 @@ case class Transaction(database: Database) extends Session(database) {

def executeInsertReturningKeys[T](command: Insert)(implicit columnType: ColumnType[T]): List[T] =
withConnection { connection =>
val (preprocessedCommand, sql, argumentLists) = database.statementBuilder(command)
val (preprocessedCommand, sql, argumentLists, prettySql) = database.statementBuilder(command)
val startTime = new DateTime
try {
val preparedStatement = prepareStatement(
Expand All @@ -303,7 +308,7 @@ case class Transaction(database: Database) extends Session(database) {
val rs = preparedStatement.getGeneratedKeys
val keys = IndexedExtractor[T](1).extractAll(ResultSetIterable(rs))
val endTime = new DateTime
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, sql, argumentLists)}")
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, prettySql, argumentLists)}")
keys
} finally {
try {
Expand Down
60 changes: 41 additions & 19 deletions sqlest/src/main/scala/sqlest/sql/DB2StatementBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ trait DB2StatementBuilder extends base.StatementBuilder {
case mappedColumnType: MappedColumnType[_, _] => castLiteralSql(mappedColumnType.baseColumnType)
}

override def selectSql(select: Select[_, _ <: Relation]): String = {
override def selectSql(select: Select[_, _ <: Relation], indent: Int): String = {
val offset = select.offset getOrElse 0L
if (offset > 0L) {
rowNumberSelectSql(select, offset, select.limit)
rowNumberSelectSql(select, offset, select.limit, indent)
} else {
super.selectSql(select)
super.selectSql(select, indent)
}
}

Expand All @@ -71,40 +71,62 @@ trait DB2StatementBuilder extends base.StatementBuilder {
override def selectOffsetSql(offset: Option[Long]): Option[String] =
None

override def joinSql(relation: Relation): String = relation match {
case tableFunctionApplication: TableFunctionApplication[_] => "table(" + functionSql(tableFunctionApplication.tableName, tableFunctionApplication.parameterColumns.map(addTypingToSqlColumn)) + ") as " + identifierSql(tableFunctionApplication.tableAlias)
case TableFunctionFromSelect(select, alias) => "table(" + selectSql(select) + ") as " + identifierSql(alias)
case LeftExceptionJoin(left, right, condition) => joinSql(left) + " left exception join " + joinSql(right) + " on " + columnSql(condition)
case RightExceptionJoin(left, right, condition) => joinSql(left) + " right exception join " + joinSql(right) + " on " + columnSql(condition)
case _ => super.joinSql(relation)
override def joinSql(relation: Relation, indent: Int): String = relation match {
case tableFunctionApplication: TableFunctionApplication[_] => "table(" + functionSql(tableFunctionApplication.tableName, tableFunctionApplication.parameterColumns.map(addTypingToSqlColumn), indent) + ") as " + identifierSql(tableFunctionApplication.tableAlias)
case TableFunctionFromSelect(select, alias) =>
"table(" +
onNewLine(selectSql(select, indent + TabWidth), indent + TabWidth) +
onNewLine(") as " + identifierSql(alias), indent)
case LeftExceptionJoin(left, right, condition) =>
joinSql(left, indent) +
onNewLine("left exception join ", indent) +
joinSql(right, indent) +
onNewLine("on ", indent) +
columnSql(condition, indent)
case RightExceptionJoin(left, right, condition) =>
joinSql(left, indent) +
onNewLine("right exception join ", indent) +
joinSql(right, indent) +
onNewLine("on ", indent) +
columnSql(condition, indent)
case _ => super.joinSql(relation, indent)
}

def rowNumberSelectSql(select: Select[_, _ <: Relation], offset: Long, limit: Option[Long]): String = {
def rowNumberSelectSql(select: Select[_, _ <: Relation], offset: Long, limit: Option[Long], indent: Int): String = {
val orderBy = selectOrderBySql(select.orderBy, indent).getOrElse("")
val whatColumns = Seq(selectWhatSql(select.columns, indent + TabWidth), s"row_number() over ($orderBy) as rownum")
val whatSql = withLineBreaks(whatColumns, indent + (TabWidth * 2))("", ", ", "")

val subquery = Seq(
s"${selectWhatSql(select.columns)}, row_number() over (${selectOrderBySql(select.orderBy) getOrElse ""}) as rownum",
selectFromSql(select.from)
whatSql,
selectFromSql(select.from, indent + TabWidth)
) ++ Seq(
selectWhereSql(select.where),
selectGroupBySql(select.groupBy)
).flatten mkString " "
selectWhereSql(select.where, indent + TabWidth),
selectGroupBySql(select.groupBy, indent + TabWidth)
).flatten mkString (NewLine + padding(indent + TabWidth))

val what =
select.columns map (col => identifierSql(col.columnAlias)) mkString ", "
withLineBreaks(select.columns.map(col => identifierSql(col.columnAlias)), indent)("", ", ", "")

val bounds = limit
.map(limit => s"rownum between ? and ?")
.getOrElse(s"rownum >= ?")

s"with subquery as ($subquery) select $what from subquery where $bounds"
s"with subquery as (" +
onNewLine(subquery, indent + TabWidth) +
onNewLine(")", indent) +
onNewLine(s"select $what", indent) +
onNewLine("from subquery", indent) +
onNewLine(s"where $bounds", indent)
}

override def columnSql(column: Column[_]): String =
override def columnSql(column: Column[_], indent: Int): String =
column match {
case literalColumn: LiteralColumn[_] if literalColumn.columnType == BooleanColumnType =>
if (literalColumn.value == true) "(? = ?)" else "(? <> ?)"
case constantColumn: ConstantColumn[_] if constantColumn.columnType == BooleanColumnType =>
if (constantColumn.value == true) "(0 = 0)" else "(0 <> 0)"
case _ => super.columnSql(column)
case _ => super.columnSql(column, indent)
}

override def selectArgs(select: Select[_, _ <: Relation]): List[LiteralColumn[_]] = {
Expand Down
4 changes: 2 additions & 2 deletions sqlest/src/main/scala/sqlest/sql/H2StatementBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package sqlest.sql
import sqlest.ast._

trait H2StatementBuilder extends base.StatementBuilder {
override def groupSql(group: Group): String = group match {
override def groupSql(group: Group, indent: Int): String = group match {
case group: FunctionGroup => throw new UnsupportedOperationException
case group => super.groupSql(group)
case group => super.groupSql(group, indent)
}
}

Expand Down
Loading