Skip to content

Commit

Permalink
fix: serialization with lambda functions
Browse files Browse the repository at this point in the history
  • Loading branch information
killme2008 committed Oct 4, 2023
1 parent b3784d2 commit c8122a7
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 44 deletions.
17 changes: 13 additions & 4 deletions src/main/java/com/googlecode/aviator/AviatorEvaluatorInstance.java
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,7 @@ private void loadSystemFunctions() {

private void loadInternalLibs() {
if (getEvalMode() == EvalMode.ASM) {

if (internalASMLibFunctions == null) {
internalASMLibFunctions = loadInternalFunctions(); // cache it
} else {
Expand Down Expand Up @@ -1081,10 +1082,18 @@ private Map<String, AviatorFunction> loadInternalFunctions() {
AviatorEvaluatorInstance(final EvalMode evalMode) {
fillDefaultOpts();
setOption(Options.EVAL_MODE, evalMode);
loadFeatureFunctions();
loadLib();
loadModule();
addFunctionLoader(ClassPathConfigFunctionLoader.getInstance());

// Load libs with Options.SERIALIZABLE=true
boolean serializable = this.getOptionValue(Options.SERIALIZABLE).bool;
try {
this.setOption(Options.SERIALIZABLE, true);
loadFeatureFunctions();
loadLib();
loadModule();
addFunctionLoader(ClassPathConfigFunctionLoader.getInstance());
} finally {
this.setOption(Options.SERIALIZABLE, serializable);
}
}

private void fillDefaultOpts() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public LambdaGenerator(final AviatorEvaluatorInstance instance,
this.inheritEnv = inheritEnv;
// Generate lambda class name
this.className =
"Lambda_" + System.currentTimeMillis() + "_" + LAMBDA_COUNTER.getAndIncrement();
"AviatorScript_" + System.currentTimeMillis() + "_" + LAMBDA_COUNTER.getAndIncrement();
// Auto compute frames
// this.classWriter = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
// visitClass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
**/
package com.googlecode.aviator.runtime.type;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
Expand Down Expand Up @@ -45,8 +48,20 @@
public class AviatorJavaType extends AviatorObject {
private static final long serialVersionUID = -4353225521490659987L;
protected String name;
private final boolean containsDot;
private boolean containsDot;
private String[] subNames;
private SymbolTable symbolTable;

private void readObject(ObjectInputStream input) throws ClassNotFoundException, IOException {
String name = (String) input.readObject();
SymbolTable symbolTable = (SymbolTable) input.readObject();
init(name, symbolTable);
}

private void writeObject(ObjectOutputStream output) throws IOException {
output.writeObject(this.name);
output.writeObject(this.symbolTable);
}

@Override
public AviatorType getAviatorType() {
Expand All @@ -63,6 +78,10 @@ public AviatorJavaType(final String name) {

public AviatorJavaType(final String name, final SymbolTable symbolTable) {
super();
init(name, symbolTable);
}

private void init(final String name, final SymbolTable symbolTable) {
if (name != null) {
String rName = reserveName(name);
if (rName != null) {
Expand All @@ -79,6 +98,7 @@ public AviatorJavaType(final String name, final SymbolTable symbolTable) {
this.name = null;
this.containsDot = false;
}
this.symbolTable = symbolTable;
}

/**
Expand Down Expand Up @@ -136,8 +156,6 @@ public AviatorObject div(final AviatorObject other, final Map<String, Object> en
}
}



@Override
public AviatorObject match(final AviatorObject other, final Map<String, Object> env) {
Object val = getValue(env);
Expand Down Expand Up @@ -354,7 +372,6 @@ public static Object getValueFromEnv(final String name, final boolean nameContai
return null;
}


@Override
public AviatorObject defineValue(final AviatorObject value, final Map<String, Object> env) {
if (this.containsDot) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import com.googlecode.aviator.lexer.SymbolTable;
import com.googlecode.aviator.lexer.token.Variable;
import com.googlecode.aviator.parser.AviatorClassLoader;
import com.googlecode.aviator.runtime.LambdaFunctionBootstrap;
import com.googlecode.aviator.runtime.type.AviatorBigInt;
import com.googlecode.aviator.runtime.type.AviatorBoolean;
import com.googlecode.aviator.runtime.type.AviatorNil;
import com.googlecode.aviator.runtime.type.Range;
import com.googlecode.aviator.utils.Env;
import com.googlecode.aviator.utils.Reflector;

/**
Expand Down Expand Up @@ -47,14 +49,10 @@ protected Object resolveObject(Object obj) throws IOException {
Object object = super.resolveObject(obj);
if (object instanceof BaseExpression) {
BaseExpression exp = (BaseExpression) object;
exp.setInstance(this.instance);
if (exp.getCompileEnv() != null) {
exp.getCompileEnv().setInstance(this.instance);
}
if (object instanceof ClassExpression) {
((ClassExpression) object)
.setClassBytes(this.classBytesCache.get(object.getClass().getName()));
}
configureExpression(exp);
}
if (object instanceof Env) {
((Env) object).setInstance(this.instance);
}

// Processing some internal constants.
Expand Down Expand Up @@ -83,6 +81,13 @@ protected Object resolveObject(Object obj) throws IOException {
return object;
}

private void configureExpression(BaseExpression exp) {
exp.setInstance(this.instance);
if (exp instanceof ClassExpression) {
((ClassExpression) exp).setClassBytes(this.classBytesCache.get(exp.getClass().getName()));
}
}

@Override
protected Class<?> resolveClass(ObjectStreamClass desc)
throws IOException, ClassNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ protected void annotateClass(Class<?> cl) throws IOException {
if (ClassExpression.class.isAssignableFrom(cl) && cl != ClassExpression.class) {
byte[] classBytes = this.classBytesCache.get(cl.getName());
if (classBytes == null) {
throw new IllegalArgumentException(
"Class bytes not found, forgot to enable Options.SERIALIZABLE before compiling the script?");
throw new IllegalArgumentException("Class bytes not found: " + cl.getName()
+ ", forgot to enable Options.SERIALIZABLE before compiling the script?");
}
this.writeInt(classBytes.length);
this.write(classBytes);
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/googlecode/aviator/utils/ArrayHashMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ public class ArrayHashMap<K, V> extends AbstractMap<K, V>

private static final long serialVersionUID = 362498820763181265L;

private static class MapEntry<K, V> implements Map.Entry<K, V> {
private static class MapEntry<K, V> implements Map.Entry<K, V>, Serializable {

private static final long serialVersionUID = 1759214536880718767L;
K key;
V value;
int hash;
Expand Down
99 changes: 75 additions & 24 deletions src/main/java/com/googlecode/aviator/utils/Env.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import com.googlecode.aviator.AviatorEvaluatorInstance;
import com.googlecode.aviator.Expression;
import com.googlecode.aviator.Feature;
Expand Down Expand Up @@ -296,6 +298,76 @@ public Set<Entry<String, Object>> entrySet() {
return ret;
}

static class TargetObjectTask implements GetValueTask {

public TargetObjectTask(Object target) {
super();
this.target = target;
}

Object target;

@Override
public Object call(Env env) {
return target;
}

}

static interface GetValueTask {
Object call(Env env);
}

/**
* Internal variable tasks to get the value.
*/
private static final IdentityHashMap<String, GetValueTask> INTERNAL_VARIABLES =
new IdentityHashMap<String, GetValueTask>();

static {
INTERNAL_VARIABLES.put(Constants.REDUCER_LOOP_VAR, new TargetObjectTask(Range.LOOP));
INTERNAL_VARIABLES.put(Constants.REDUCER_EMPTY_VAR,
new TargetObjectTask(ReducerResult.withEmpty(AviatorNil.NIL)));
INTERNAL_VARIABLES.put(Constants.ENV_VAR, new GetValueTask() {

@Override
public Object call(Env env) {
env.instance.ensureFeatureEnabled(Feature.InternalVars);
return env;
}

});
INTERNAL_VARIABLES.put(Constants.FUNC_ARGS_VAR, new GetValueTask() {

@Override
public Object call(Env env) {
env.instance.ensureFeatureEnabled(Feature.InternalVars);
return FunctionUtils.getFunctionArguments(env);
}

});
INTERNAL_VARIABLES.put(Constants.INSTANCE_VAR, new GetValueTask() {

@Override
public Object call(Env env) {
env.instance.ensureFeatureEnabled(Feature.InternalVars);
return env.instance;
}

});

INTERNAL_VARIABLES.put(Constants.EXP_VAR, new GetValueTask() {

@Override
public Object call(Env env) {
env.instance.ensureFeatureEnabled(Feature.InternalVars);
return env.expression;
}

});

}

/**
* Get value for key. If the key is present in the overrides map, the value from that map is
* returned; otherwise, the value for the key in the defaults map is returned.
Expand All @@ -305,30 +377,9 @@ public Set<Entry<String, Object>> entrySet() {
*/
@Override
public Object get(final Object key) {
// Should check ENV_VAR at first
// TODO: performance tweak
if (Constants.REDUCER_LOOP_VAR.equals(key)) {
return Range.LOOP;
}
if (Constants.REDUCER_EMPTY_VAR.equals(key)) {
return ReducerResult.withEmpty(AviatorNil.NIL);
}

if (Constants.ENV_VAR.equals(key)) {
this.instance.ensureFeatureEnabled(Feature.InternalVars);
return this;
}
if (Constants.FUNC_ARGS_VAR.equals(key)) {
this.instance.ensureFeatureEnabled(Feature.InternalVars);
return FunctionUtils.getFunctionArguments(this);
}
if (Constants.INSTANCE_VAR.equals(key)) {
this.instance.ensureFeatureEnabled(Feature.InternalVars);
return this.instance;
}
if (Constants.EXP_VAR.equals(key)) {
this.instance.ensureFeatureEnabled(Feature.InternalVars);
return this.expression;
GetValueTask task = INTERNAL_VARIABLES.get(key);
if (task != null) {
return task.call(this);
}

Map<String, Object> overrides = getmOverrides(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ public void testFunctions() {
assertEquals(610, testScript("fibonacci.av", "n", 15));
assertEquals(6765, testScript("fibonacci.av", "n", 20));
testScript("unpacking_arguments.av");
assertEquals(Arrays.asList(3L, 2L, 4L, 1L), testScript("recusive_fn.av"));
}

@Test
Expand Down

0 comments on commit c8122a7

Please sign in to comment.