Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SparkR-237, 238] Fix cleanClosure by including private function checks in package namespaces. #229

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,5 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")
useDynLib(SparkR, stringHashCode)
useDynLib(SparkR, getMissingArg)
importFrom(methods, setGeneric, setMethod, setOldClass)
16 changes: 4 additions & 12 deletions pkg/R/RDD.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val)

if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) {
# This transformation is the first in its stage:
.Object@func <- func
.Object@func <- cleanClosure(func)
.Object@prev_jrdd <- getJRDD(prev)
.Object@env$prev_serializedMode <- prev@env$serializedMode
# NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD
Expand All @@ -75,7 +75,7 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val)
pipelinedFunc <- function(split, iterator) {
func(split, prev@func(split, iterator))
}
.Object@func <- pipelinedFunc
.Object@func <- cleanClosure(pipelinedFunc)
.Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline
# Get the serialization mode of the parent RDD
.Object@env$prev_serializedMode <- prev@env$prev_serializedMode
Expand Down Expand Up @@ -123,17 +123,13 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
return(rdd@env$jrdd_val)
}

computeFunc <- function(split, part) {
rdd@func(split, part)
}

packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)

broadcastArr <- lapply(ls(.broadcastNames),
function(name) { get(name, .broadcastNames) })

serializedFuncArr <- serialize(computeFunc, connection = NULL)
serializedFuncArr <- serialize(rdd@func, connection = NULL)

prev_jrdd <- rdd@prev_jrdd

Expand Down Expand Up @@ -529,11 +525,7 @@ setMethod("mapPartitions",
setMethod("lapplyPartitionsWithIndex",
signature(X = "RDD", FUN = "function"),
function(X, FUN) {
FUN <- cleanClosure(FUN)
closureCapturingFunc <- function(split, part) {
FUN(split, part)
}
PipelinedRDD(X, closureCapturingFunc)
PipelinedRDD(X, FUN)
})

#' @rdname lapplyPartitionsWithIndex
Expand Down
68 changes: 41 additions & 27 deletions pkg/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,76 +291,83 @@ numToInt <- function(num) {
# param
# node The current AST node in the traversal.
# oldEnv The original function environment.
# newEnv A new function environment to store necessary function dependencies, an output argument.
# defVars An Accumulator of variable names defined in the function's calling environment,
# including function argument and local variable names.
# checkedFuncs An environment of function objects examined during cleanClosure. It can
# be considered as a "name"-to-"list of functions" mapping.
# newEnv A new function environment to store necessary function dependencies, an output argument.
processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
processClosure <- function(node, newEnv, oldEnv = environment(), defVars = initAccumulator(),
checkedFuncs = new.env()) {
nodeLen <- length(node)

if (nodeLen > 1 && typeof(node) == "language") {
# Recursive case: current AST node is an internal node, check for its children.
if (length(node[[1]]) > 1) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[i]], newEnv, oldEnv, defVars, checkedFuncs)
}
} else { # if node[[1]] is length of 1, check for some R special functions.
nodeChar <- as.character(node[[1]])
if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[i]], newEnv, oldEnv, defVars, checkedFuncs)
}
} else if (nodeChar == "<-" || nodeChar == "=" ||
nodeChar == "<<-") { # Assignment Ops.
for (i in 3:nodeLen) {
processClosure(node[[i]], newEnv, oldEnv, defVars, checkedFuncs)
}
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
# Add the defined variable name into defVars.
addItemToAccumulator(defVars, as.character(defVar))
} else {
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
}
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[2]], newEnv, oldEnv, defVars, checkedFuncs)
}
} else if (nodeChar == "function") { # Function definition.
# Add parameter names.
newArgs <- names(node[[2]])
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[i]], newEnv, oldEnv, defVars, checkedFuncs)
}
defVars$counter <- defVars$counter - length(newArgs)
} else if (nodeChar == "$") { # Skip the field.
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[2]], newEnv, oldEnv, defVars, checkedFuncs)
} else if (nodeChar == "::" || nodeChar == ":::") {
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[3]], newEnv, oldEnv, defVars, checkedFuncs)
} else {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
processClosure(node[[i]], newEnv, oldEnv, defVars, checkedFuncs)
}
}
}
} else if (nodeLen == 1 &&
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
if (!nodeChar %in% defVars$data[1:defVars$counter]) { # Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
# Search in function environment, and function's enclosing environments
# up to global environment. There is no need to look into package environments
# above the global or namespace environment that is not SparkR below the global,
# as they are assumed to be loaded on workers.
# above the global, as they are assumed to be loaded on workers.
while (!identical(func.env, topEnv)) {
# Namespaces other than "SparkR" will not be searched.
# Only examine functions in non-namespace environments or private functions in
# package namespaces.
if (!isNamespace(func.env) ||
(getNamespaceName(func.env) == "SparkR" &&
!(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
!(nodeChar %in% getNamespaceExports(getNamespaceName(func.env)))
) {
# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
obj <- tryCatch(get(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { print(e); .Call("getMissingArg") })
if (missing(obj)) {
assign(nodeChar, .Call("getMissingArg"), envir = newEnv)
break
}
if (is.function(obj)) { # If the node is a function call.
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
Expand All @@ -371,14 +378,21 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
break
}
# Function has not been examined, record it and recursively clean its closure.
assign(nodeChar,
if (is.null(funcList[[1]])) {
list(obj)
} else {
append(funcList, obj)
},
envir = checkedFuncs)
newFuncList <- if (is.null(funcList[[1]])) {
list(obj)
} else {
append(funcList, obj)
}
assign(nodeChar, newFuncList, envir = checkedFuncs)
obj <- cleanClosure(obj, checkedFuncs)
parent.env(environment(obj)) <- newEnv
# Remove examined functions.
newFuncList[[length(newFuncList)]] <- NULL
if (length(newFuncList) > 0) {
assign(nodeChar, newFuncList, envir = checkedFuncs)
} else {
remove(list = nodeChar, envir = checkedFuncs)
}
}
assign(nodeChar, obj, envir = newEnv)
break
Expand Down Expand Up @@ -414,7 +428,7 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
addItemToAccumulator(defVars, argNames[i])
}
# Recursively examine variables in the function body.
processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv)
processClosure(func.body, newEnv, oldEnv, defVars, checkedFuncs)
environment(func) <- newEnv
}
func
Expand Down
40 changes: 32 additions & 8 deletions pkg/inst/tests/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,35 +66,59 @@ test_that("cleanClosure on R functions", {
field <- matrix(2)
defUse <- 3
g <- function(x) { x + y }
h <- function(x) { f(x) }
f <- function(x) {
defUse <- base::as.integer(x) + 1 # Test for access operators `::`.
lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply.
l$field[1,1] <- 3 # Test for access operators `$`.
res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol.
f(res) # Test for recursive calls.
h(res) # f should be examined again in h's env. More recursive call f -> h -> f ...
}
newF <- cleanClosure(f)
env <- environment(newF)
expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse".
expect_equal(length(ls(env)), 4) # Only "g", "l", "f" and "h". No "base", "field" or "defUse".
expect_true("g" %in% ls(env))
expect_true("l" %in% ls(env))
expect_true("f" %in% ls(env))
expect_true("h" %in% ls(env))
expect_equal(get("l", envir = env, inherits = FALSE), l)
# "y" should be in the environment of g.
newG <- get("g", envir = env, inherits = FALSE)
newH <- get("h", envir = env, inherits = FALSE)
# "y" should be in the environment of g.
env <- environment(newG)
expect_equal(length(ls(env)), 1)
expect_equal(ls(env), "y")
actual <- get("y", envir = env, inherits = FALSE)
expect_equal(actual, y)
# "f" should be in h's env.
env <- environment(newH)
expect_equal(ls(env), "f")
actual <- get("f", envir = env, inherits = FALSE)
expect_equal(actual, f)

# Test for function (and variable) definitions.
f <- function(x) {
g <- function(y) { y * 2 }
g(x)
z <- c(1, 2)
f <- function(x, y) {
privateCallRes <- unlist(joinTaggedList(x, y)) # Call package private functions.
g <- function(y, z) { y * 2 }
z <- z * 2 # Write after read.
g(privateCallRes)
}
newF <- cleanClosure(f)
env <- environment(newF)
expect_equal(length(ls(env)), 0) # "y" and "g" should not be included.
expect_equal(length(ls(env)), 2) # Only "joinTaggedList" and "z". No "y" or "g".
expect_true("joinTaggedList" %in% ls(env))
expect_true("z" %in% ls(env))
actual <- get("joinTaggedList", envir = env, inherits = FALSE)
expect_equal(actual, joinTaggedList)
env <- environment(actual)
# Private "genCompactLists" and "mergeCompactLists" are called by "joinTaggedList".
expect_true("genCompactLists" %in% ls(env))
expect_true("mergeCompactLists" %in% ls(env))
actual <- get("genCompactLists", envir = env, inherits = FALSE)
expect_equal(actual, genCompactLists)
actual <- get("mergeCompactLists", envir = env, inherits = FALSE)
expect_equal(actual, mergeCompactLists)

# Test for overriding variables in base namespace (Issue: SparkR-196).
nums <- as.list(1:10)
Expand All @@ -113,7 +137,7 @@ test_that("cleanClosure on R functions", {
a <- matrix(nrow=10, ncol=10, data=rnorm(100))
aBroadcast <- broadcast(sc, a)
normMultiply <- function(x) { norm(aBroadcast$value) * x }
newnormMultiply <- SparkR:::cleanClosure(normMultiply)
newnormMultiply <- cleanClosure(normMultiply)
env <- environment(newnormMultiply)
expect_equal(ls(env), "aBroadcast")
expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast)
Expand Down
6 changes: 5 additions & 1 deletion pkg/inst/worker/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ for (pkg in packageNames) {
funcLen <- SparkR:::readInt(inputCon)
computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
env <- environment(computeFunc)
parent.env(env) <- .GlobalEnv # Attach under global environment.
if (length(packageNames) > 0) {
parent.env(env) <- getNamespace(packageNames[1]) # Attach under package namespace environment.
} else {
parent.env(env) <- .GlobalEnv # Attach under global environment.
}

# Read and set broadcast variables
numBroadcastVars <- SparkR:::readInt(inputCon)
Expand Down
4 changes: 2 additions & 2 deletions pkg/src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ $(MAVEN_TARGET_NAME): pom.xml $(SCALA_FILES) $(RESOURCE_FILES)
mvn -Dhadoop.version=$(SPARK_HADOOP_VERSION) -Dspark.version=$(SPARK_VERSION) -DskipTests $(MAVEN_YARN_FLAG) -Dyarn.version=$(SPARK_YARN_VERSION) clean package shade:shade
cp -f $(MAVEN_TARGET_NAME) ../inst/$(JAR_NAME)

sharelib: string_hash_code.c
R CMD SHLIB -o SparkR.so string_hash_code.c
sharelib: utils.c
R CMD SHLIB -o SparkR.so utils.c

clean:
$(BUILD_TOOL) clean
Expand Down
4 changes: 2 additions & 2 deletions pkg/src/Makefile.win
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ $(MAVEN_TARGET_NAME): $(SCALA_FILES) $(RESOURCE_FILES)
mvn.bat -Dhadoop.version=$(SPARK_HADOOP_VERSION) -Dspark.version=$(SPARK_VERSION) -Dyarn.version=$(SPARK_YARN_VERSION) -DskipTests clean package shade:shade
cp -f $(MAVEN_TARGET_NAME) ../inst/$(JAR_NAME)

sharelib: string_hash_code.c
R CMD SHLIB -o SparkR.dll string_hash_code.c
sharelib: utils.c
R CMD SHLIB -o SparkR.dll utils.c

clean:
mvn.bat clean
Expand Down
19 changes: 13 additions & 6 deletions pkg/src/string_hash_code.c → pkg/src/utils.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
/*
* A C function for R extension which implements the Java String hash algorithm.
* Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function
*
*/

#include <R.h>
#include <Rinternals.h>

Expand All @@ -12,6 +6,11 @@
#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1)
#endif

/*
* A C function for R extension which implements the Java String hash algorithm.
* Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function
*
*/
SEXP stringHashCode(SEXP string) {
const char* str;
R_xlen_t len, i;
Expand All @@ -30,3 +29,11 @@ SEXP stringHashCode(SEXP string) {

return ScalarInteger(hashCode);
}

/*
 * Get the value of the missing argument symbol.
 *
 * Returns R_MissingArg unchanged so that R code can obtain the marker through
 * .Call("getMissingArg"). The function takes no arguments; it is declared
 * with (void) so the definition is a proper C prototype rather than an
 * old-style declaration with an unspecified parameter list.
 */
SEXP getMissingArg(void) {
  return R_MissingArg;
}