SqlScriptingExecution.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkThrowable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.SqlScriptingContextManager
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, ExceptionHandlerType, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.types.StructType
@@ -86,6 +86,33 @@ class SqlScriptingExecution(
currExecPlan.curr = Some(new LeaveStatementExec(label))
}

+/**
+ * Helper method to skip a conditional statement in the execution plan.
+ * @param executionPlan Execution plan to skip a conditional statement in.
+ */
+private def skipConditionalStatement(executionPlan: NonLeafStatementExec): Unit = {
+  // Descend as deep as possible to find the innermost non-leaf statement that is
+  // currently executing; that conditional statement is the one to skip, instead of
+  // letting it produce the next statement to execute.
+  var currExecPlan = executionPlan
+  while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) {
+    currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec]
+  }
+
+  currExecPlan match {
+    case exec: IfElseStatementExec =>
+      exec.curr = None
+    case exec: SimpleCaseStatementExec =>
+      exec.skipSimpleCaseStatement()
+    case exec: SearchedCaseStatementExec =>
+      exec.curr = None
+    case exec: WhileStatementExec =>
+      exec.curr = None
+    case exec: ForStatementExec =>
+      exec.skipForStatement()
+    case _ =>
+  }
+}
+
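Note: a sketch of why the conditional is skipped rather than resumed. If the error is raised while evaluating the condition of an IF/CASE/WHILE/FOR, re-entering that statement after the CONTINUE handler runs would re-evaluate the failing condition and raise again, so the whole conditional is abandoned. The script below is illustrative only; it borrows the handler syntax and the DIVIDE_BY_ZERO condition name from the tests in this PR, and the simple-CASE statement syntax is an assumption:

val caseSkipScript =
  """
    |BEGIN
    |  DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
    |  BEGIN
    |    SELECT 'handled';
    |  END;
    |  -- The CASE operand 1/0 raises DIVIDE_BY_ZERO; after the handler runs,
    |  -- skipConditionalStatement() drops the whole CASE statement.
    |  CASE 1/0
    |    WHEN 1 THEN SELECT 'one';
    |    ELSE SELECT 'other';
    |  END CASE;
    |  SELECT 'execution resumes here';
    |END
    |""".stripMargin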
/** Helper method to iterate get next statements from the first available frame. */
private def getNextStatement: Option[CompoundStatementExec] = {
// Remove frames that are already executed.
@@ -103,14 +130,24 @@

// If the last frame is a handler, set leave statement to be the next one in the
// innermost scope that should be exited.
-if (lastFrame.frameType == SqlScriptingFrameType.HANDLER && context.frames.nonEmpty) {
+if ((lastFrame.frameType == SqlScriptingFrameType.EXIT_HANDLER
+    || lastFrame.frameType == SqlScriptingFrameType.CONTINUE_HANDLER)
+    && context.frames.nonEmpty) {
// Remove the scope if handler is executed.
if (context.firstHandlerScopeLabel.isDefined
&& lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) {
context.firstHandlerScopeLabel = None
}
-// Inject leave statement into the execution plan of the last frame.
-injectLeaveStatement(context.frames.last.executionPlan, lastFrame.scopeLabel.get)
+
+if (lastFrame.frameType == SqlScriptingFrameType.EXIT_HANDLER) {
+  // Inject leave statement into the execution plan of the last frame.
+  injectLeaveStatement(context.frames.last.executionPlan, lastFrame.scopeLabel.get)
+}
+
+if (lastFrame.frameType == SqlScriptingFrameType.CONTINUE_HANDLER
+    && context.frames.nonEmpty) {
+  skipConditionalStatement(context.frames.last.executionPlan)
+}
}
}
// If there are still frames available, get the next statement.
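Note: the two branches above encode the semantic difference between the handler types. An EXIT handler unwinds the scope that declared it by injecting a LEAVE into the frame below, while a CONTINUE handler skips only the statement (or enclosing conditional) that raised, so execution proceeds with the next statement. A hypothetical pair of scripts contrasting the two; the CONTINUE syntax is taken from the tests in this PR, and the EXIT variant of the DECLARE syntax is assumed to mirror it:

val exitScript =
  """
    |BEGIN
    |  DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO
    |  BEGIN
    |    SELECT 'handled';
    |  END;
    |  SELECT 1/0;
    |  SELECT 'after';  -- never runs: EXIT leaves the declaring scope
    |END
    |""".stripMargin

val continueScript =
  """
    |BEGIN
    |  DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
    |  BEGIN
    |    SELECT 'handled';
    |  END;
    |  SELECT 1/0;
    |  SELECT 'after';  -- runs: CONTINUE resumes after the failed statement
    |END
    |""".stripMargin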
@@ -169,7 +206,11 @@
case Some(handler) =>
val handlerFrame = new SqlScriptingExecutionFrame(
handler.body,
-SqlScriptingFrameType.HANDLER,
+if (handler.handlerType == ExceptionHandlerType.CONTINUE) {
+  SqlScriptingFrameType.CONTINUE_HANDLER
+} else {
+  SqlScriptingFrameType.EXIT_HANDLER
+},
Review comment on lines +209 to +213 (Contributor):
IMO this ExceptionHandlerType.EXIT/CONTINUE and SqlScriptingFrameType.EXIT_HANDLER/CONTINUE_HANDLER is a bit confusing. We should maybe find a way to reuse information about handler type in execution frame. Let's discuss this offline first.

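Note: one possible direction for the suggestion above, sketched with hypothetical names and not part of this PR. It centralizes the ExceptionHandlerType-to-SqlScriptingFrameType mapping in one helper so the branching is not repeated at each call site (the type ascriptions assume both are Enumerations, as their usage here suggests):

// Hypothetical helper; derives the frame type from the handler type in one place.
private def frameTypeFor(
    handlerType: ExceptionHandlerType.Value): SqlScriptingFrameType.Value =
  handlerType match {
    case ExceptionHandlerType.CONTINUE => SqlScriptingFrameType.CONTINUE_HANDLER
    case _ => SqlScriptingFrameType.EXIT_HANDLER
  }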
handler.scopeLabel
)
context.frames.append(

SqlScriptingExecutionContext.scala
@@ -59,7 +59,8 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension
}

// If the last frame is a handler, try to find a handler in its body first.
-if (frames.last.frameType == SqlScriptingFrameType.HANDLER) {
+if (frames.last.frameType == SqlScriptingFrameType.EXIT_HANDLER
+    || frames.last.frameType == SqlScriptingFrameType.CONTINUE_HANDLER) {
val handler = frames.last.findHandler(condition, sqlState, firstHandlerScopeLabel)
if (handler.isDefined) {
return handler
@@ -83,7 +84,7 @@ class SqlScriptingExecutionContext extends SqlScriptingExecutionContextExtension

object SqlScriptingFrameType extends Enumeration {
type SqlScriptingFrameType = Value
-val SQL_SCRIPT, HANDLER = Value
+val SQL_SCRIPT, EXIT_HANDLER, CONTINUE_HANDLER = Value
}

/**
@@ -141,7 +142,9 @@ class SqlScriptingExecutionFrame(
sqlState: String,
firstHandlerScopeLabel: Option[String]): Option[ExceptionHandlerExec] = {

-val searchScopes = if (frameType == SqlScriptingFrameType.HANDLER) {
+val searchScopes =
+  if (frameType == SqlScriptingFrameType.EXIT_HANDLER
+      || frameType == SqlScriptingFrameType.CONTINUE_HANDLER) {
// If the frame is a handler, search for the handler in its body. Don't skip any scopes.
scopes.reverseIterator
} else if (firstHandlerScopeLabel.isEmpty) {

SqlScriptingExecutionNode.scala
@@ -697,6 +697,12 @@ class SimpleCaseStatementExec(
conditionBodyTupleIterator
}

+protected[scripting] def skipSimpleCaseStatement(): Unit = {
+  // Reset the iterator state so that the next hasNext call returns false
+  // and the remaining CASE body is never entered.
+  this.state = CaseState.Body
+  this.bodyExec = None
+}
+
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
override def hasNext: Boolean = state match {
@@ -1025,6 +1031,12 @@ class ForStatementExec(
*/
private var firstIteration: Boolean = true

+protected[scripting] def skipForStatement(): Unit = {
+  // Reset the iterator state so that the next hasNext call returns false
+  // and the remaining FOR body is never entered.
+  this.state = ForState.Body
+  this.bodyWithVariables = None
+}
+
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {

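Note: both skip methods rely on the same idiom — jump the statement's state machine past its condition phase and drop the body, so the tree iterator reports exhaustion on the next hasNext. A minimal, self-contained Scala sketch of that idiom, with hypothetical names (State, LoopExec, skip) standing in for the real exec nodes:

object SkipIdiomDemo extends App {
  object State extends Enumeration { val Condition, Body = Value }

  // Hypothetical stand-in for a statement-exec node with a two-state iterator.
  class LoopExec(private var body: Option[List[String]]) {
    var state: State.Value = State.Condition

    // Mirrors skipSimpleCaseStatement()/skipForStatement(): pretend the body
    // phase was already entered, then drop the body so iteration ends at once.
    def skip(): Unit = {
      state = State.Body
      body = None
    }

    def hasNext: Boolean = state match {
      case State.Condition => true                 // condition not yet evaluated
      case State.Body => body.exists(_.nonEmpty)   // false once skipped
    }
  }

  val loop = new LoopExec(Some(List("stmt1", "stmt2")))
  assert(loop.hasNext)
  loop.skip()
  assert(!loop.hasNext) // skipped: nothing left to execute
  println("after skip(), hasNext = " + loop.hasNext)
}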

SqlScriptingInterpreter.scala
@@ -88,7 +88,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
context)

// Execution node of handler.
-val handlerScopeLabel = if (handler.handlerType == ExceptionHandlerType.EXIT) {
+val handlerScopeLabel = if (handler.handlerType == ExceptionHandlerType.EXIT
+    || handler.handlerType == ExceptionHandlerType.CONTINUE) {
Review comment on lines +91 to +92 (Contributor):
I don't get this change. We either leave the if as is or we remove the condition, as these are the only 2 possible handler types.
Not sure if we need this label in a CONTINUE handler though, cc @miland-db what is this label used for again?

Reply (Contributor):
It's the label of the scope where exception handler is defined. @TeodorDjelic as you go through code feel free to add even more comments explaining what's going on.

Reply (Contributor):
This looks like a great place for it.

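Note: per the thread above, scopeLabel records the label of the scope in which the handler is declared. The EXIT path uses it to inject a LEAVE for that scope; with this change the CONTINUE path also relies on it for the firstHandlerScopeLabel bookkeeping in getNextStatement (the scope-removal check there applies to both handler types), which is presumably why the label is now populated for CONTINUE handlers as well.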
Some(compoundBody.label.get)
} else {
None

SqlScriptingE2eSuite.scala
@@ -35,6 +35,17 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
* For full functionality tests, see SqlScriptingParserSuite and SqlScriptingInterpreterSuite.
*/
class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {

+protected override def beforeAll(): Unit = {
+  super.beforeAll()
+  conf.setConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED, true)
+}
+
+protected override def afterAll(): Unit = {
+  conf.unsetConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key)
+  super.afterAll()
+}

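Note: the suite enables the feature flag globally in beforeAll/afterAll. An alternative sketch that scopes the conf to a single test via the standard withSQLConf helper available to SharedSparkSession suites; the test name, script, and expected result (assumed to be the final statement's output, as in the exit-handler test below) are hypothetical:

test("continue handler - conf scoped to one test") {
  withSQLConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key -> "true") {
    val script =
      """
        |BEGIN
        |  DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
        |  BEGIN
        |    SELECT 'handled';
        |  END;
        |  SELECT 1/0;
        |  SELECT 42;
        |END
        |""".stripMargin
    verifySqlScriptResult(script, Seq(Row(42)))
  }
}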
// Helpers
private def verifySqlScriptResult(
sqlText: String,
@@ -77,7 +88,7 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
}
}

test("Scripting with exception handlers") {
test("Scripting with exit exception handlers") {
val sqlScript =
"""
|BEGIN
@@ -104,6 +115,36 @@
verifySqlScriptResult(sqlScript, Seq(Row(2)))
}

test("Scripting with continue exception handlers") {
val sqlScript =
"""
|BEGIN
| DECLARE flag1 INT = -1;
| DECLARE flag2 INT = -1;
| DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
| BEGIN
| SELECT flag1;
| SET flag1 = 1;
| END;
| BEGIN
| DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
| BEGIN
| SELECT flag1;
| SET flag1 = 2;
| END;
| SELECT 5;
| SET flag2 = 1;
| SELECT 1/0;
| SELECT 6;
| SET flag2 = 2;
| END;
| SELECT 7;
| SELECT flag1, flag2;
|END
|""".stripMargin
verifySqlScriptResult(sqlScript, Seq(Row(2, 2)))
}

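Note: expected trace for the Row(2, 2) assertion, assuming verifySqlScriptResult compares the result of the script's final statement (as the exit-handler test above suggests): SELECT 1/0 raises DIVIDE_BY_ZERO (SQLSTATE '22012'), the innermost matching handler sets flag1 = 2, and because it is a CONTINUE handler execution resumes with SELECT 6 and SET flag2 = 2, so the final SELECT flag1, flag2 returns (2, 2). The outer DIVIDE_BY_ZERO handler never fires.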
test("single select") {
val sqlText = "SELECT 1;"
verifySqlScriptResult(sqlText, Seq(Row(1)))

SqlScriptingExecutionNodeSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, LeafNode, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

@@ -32,6 +33,16 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
* It is then checked if the leaf statements have been iterated in the expected order.
*/
class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSession {
+protected override def beforeAll(): Unit = {
+  super.beforeAll()
+  conf.setConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED, true)
+}
+
+protected override def afterAll(): Unit = {
+  conf.unsetConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key)
+  super.afterAll()
+}

// Helpers
case class TestCompoundBody(
override val statements: Seq[CompoundStatementExec],