Skip to content

Commit

Permalink
Merge pull request #369 from InseeFr/refactor-case
Browse files Browse the repository at this point in the history
Refactor case
  • Loading branch information
NicoLaval authored Nov 5, 2024
2 parents b4e7224 + 242ad21 commit 1dd9e40
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ public class VtlNativeMethods {
Fun.<Boolean, Double, Double>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, String, String>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, Boolean, Boolean>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, Long>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, Double>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, String>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, Boolean>toMethod(ConditionalVisitor::caseFn),
Fun.<Long, Long>toMethod(ConditionalVisitor::nvl),
Fun.<Double, Double>toMethod(ConditionalVisitor::nvl),
Fun.<Double, Long>toMethod(ConditionalVisitor::nvl),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@
import fr.insee.vtl.parser.VtlBaseVisitor;
import fr.insee.vtl.parser.VtlParser;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;

import static fr.insee.vtl.engine.VtlScriptEngine.fromContext;
import static fr.insee.vtl.engine.utils.TypeChecking.assertBoolean;
import static fr.insee.vtl.engine.utils.TypeChecking.hasSameTypeOrNull;

/**
Expand Down Expand Up @@ -68,34 +64,6 @@ public static Boolean ifThenElse(Boolean condition, Boolean thenExpr, Boolean el
return condition ? thenExpr : elseExpr;
}

public static Long caseFn(Boolean condition, Long thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Double caseFn(Boolean condition, Double thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static String caseFn(Boolean condition, String thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Boolean caseFn(Boolean condition, Boolean thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Long nvl(Long value, Long defaultValue) {
return value == null ? defaultValue : value;
}
Expand Down Expand Up @@ -149,92 +117,59 @@ public ResolvableExpression visitIfExpr(VtlParser.IfExprContext ctx) {
*/
@Override
public ResolvableExpression visitCaseExpr(VtlParser.CaseExprContext ctx) {
Positioned pos = fromContext(ctx);
List<VtlParser.ExprContext> exprs = ctx.expr();
List<VtlParser.ExprContext> whenExprs = new ArrayList<>();
List<VtlParser.ExprContext> thenExprs = new ArrayList<>();
for (int i = 0; i < exprs.size() - 1; i = i + 2) {
whenExprs.add(exprs.get(i));
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(e -> assertBoolean(exprVisitor.visit(e), e))
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
ResolvableExpression elseExpression = exprVisitor.visit(exprs.get(exprs.size() - 1));
List<ResolvableExpression> forTypeCheck = (new ArrayList<>(thenExpressions));
forTypeCheck.add(elseExpression);
// TODO: handle better the default element position
if (!hasSameTypeOrNull(forTypeCheck)) {
try {
throw new InvalidTypeException(
forTypeCheck.get(0).getClass(),
Boolean.class,
fromContext(ctx.expr(0))
);
} catch (InvalidTypeException e) {
throw new RuntimeException(e);
try {
Positioned pos = fromContext(ctx);
List<VtlParser.ExprContext> exprs = ctx.expr();
List<VtlParser.ExprContext> whenExprs = new ArrayList<>();
List<VtlParser.ExprContext> thenExprs = new ArrayList<>();
for (int i = 0; i < exprs.size() - 1; i = i + 2) {
whenExprs.add(exprs.get(i));
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
.collect(Collectors.toList());
ResolvableExpression elseExpression = exprVisitor.visit(exprs.get(exprs.size() - 1));
List<ResolvableExpression> forTypeCheck = (new ArrayList<>(thenExpressions));
forTypeCheck.add(elseExpression);
// TODO: handle better the default element position
if (!hasSameTypeOrNull(forTypeCheck)) {
try {
throw new InvalidTypeException(
forTypeCheck.get(0).getClass(),
Boolean.class,
fromContext(ctx.expr(0))
);
} catch (InvalidTypeException e) {
throw new RuntimeException(e);
}
}
}

Class<?> outputType = elseExpression.getType();

if (outputType.equals(String.class)) {
return ResolvableExpression.withType(String.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (String) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (String) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
}
if (outputType.equals(Double.class)) {
return ResolvableExpression.withType(Double.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Double) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Double) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
Class<?> outputType = elseExpression.getType();
return new CastExpression(pos, caseToIfIt(whenExpressions.listIterator(), thenExpressions.listIterator(), elseExpression), outputType);
} catch (VtlScriptException e) {
throw new VtlRuntimeException(e);
}
if (outputType.equals(Long.class)) {
return ResolvableExpression.withType(Long.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Long) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Long) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
}

private ResolvableExpression caseToIfIt(ListIterator<ResolvableExpression> whenExpr, ListIterator<ResolvableExpression> thenExpr, ResolvableExpression elseExpression) throws VtlScriptException {
if (!whenExpr.hasNext() || !thenExpr.hasNext()) {
return elseExpression;
}
if (outputType.equals(Boolean.class)) {
return ResolvableExpression.withType(Boolean.class)
.withPosition(pos)
.using(context -> {
for (int i = 0; i < whenExprs.size(); i++) {
Boolean condition = (Boolean) whenExpressions.get(i).resolve(context);
if (condition) {
return (Boolean) (new CastExpression(pos, thenExpressions.get(i), outputType)).resolve(context);
}
}
return (Boolean) (new CastExpression(pos, elseExpression, outputType)).resolve(context);
});
} else return null;

ResolvableExpression nextWhen = whenExpr.next();

return genericFunctionsVisitor.invokeFunction("ifThenElse", Java8Helpers.listOf(
nextWhen,
thenExpr.next(),
caseToIfIt(whenExpr, thenExpr, elseExpression)
), nextWhen);
}


/**
* Visits nvl expressions.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import fr.insee.vtl.engine.expressions.CastExpression;
import fr.insee.vtl.engine.expressions.ComponentExpression;
import fr.insee.vtl.engine.expressions.FunctionExpression;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.engine.visitors.expression.ExpressionVisitor;
import fr.insee.vtl.model.*;
import fr.insee.vtl.model.exceptions.VtlScriptException;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.parser.VtlBaseVisitor;
import fr.insee.vtl.parser.VtlParser;
import org.antlr.v4.runtime.Token;
Expand Down Expand Up @@ -169,7 +169,8 @@ private DatasetExpression invokeFunctionOnDataset(String funcName, List<Resolvab
.collect(Collectors.toMap(e -> "arg" + e.hashCode(), e -> e));
if (measureNames.size() != 1) {
throw new VtlRuntimeException(
new InvalidArgumentException("mono-measure datasets don't contain same measures (number or names)", position)
new InvalidArgumentException("Variables in the mono-measure datasets are not named the same: " +
measureNames + " found", position)
);
}
DatasetExpression ds = proc.executeInnerJoin(dsExprs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import fr.insee.vtl.engine.exceptions.FunctionNotFoundException;
import fr.insee.vtl.engine.samples.DatasetSamples;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.utils.Java8Helpers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -92,6 +92,15 @@ public void testCaseExpr() throws ScriptException {
Java8Helpers.mapOf("id", "Franck", "c", 1L)
);
assertThat(((Dataset) res1).getDataStructure().get("c").getType()).isEqualTo(Long.class);
engine.eval("ds1 := ds_1[keep id, long1][rename long1 to bool_var];" +
"ds2 := ds_2[keep id, long1][rename long1 to bool_var]; " +
"res_ds <- case when ds1 < 30 then ds1 else ds2;");
Object res_ds = engine.getContext().getAttribute("res_ds");
assertThat(((Dataset) res_ds).getDataAsMap()).containsExactlyInAnyOrder(
Java8Helpers.mapOf("id", "Hadrien", "bool_var", 10L),
Java8Helpers.mapOf("id", "Nico", "bool_var", 20L),
Java8Helpers.mapOf("id", "Franck", "bool_var", 100L)
);
}

@Test
Expand Down

0 comments on commit 1dd9e40

Please sign in to comment.