Skip to content

Commit

Permalink
Improve case operator
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoLaval committed Nov 5, 2024
1 parent 2d17976 commit 242ad21
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ public class VtlNativeMethods {
Fun.<Boolean, Double, Double>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, String, String>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, Boolean, Boolean>toMethod(ConditionalVisitor::ifThenElse),
Fun.<Boolean, Long>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, Double>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, String>toMethod(ConditionalVisitor::caseFn),
Fun.<Boolean, Boolean>toMethod(ConditionalVisitor::caseFn),
Fun.<Long, Long>toMethod(ConditionalVisitor::nvl),
Fun.<Double, Double>toMethod(ConditionalVisitor::nvl),
Fun.<Double, Long>toMethod(ConditionalVisitor::nvl),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import java.util.stream.Collectors;

import static fr.insee.vtl.engine.VtlScriptEngine.fromContext;
import static fr.insee.vtl.engine.utils.TypeChecking.assertBoolean;
import static fr.insee.vtl.engine.utils.TypeChecking.hasSameTypeOrNull;

/**
Expand Down Expand Up @@ -65,34 +64,6 @@ public static Boolean ifThenElse(Boolean condition, Boolean thenExpr, Boolean el
return condition ? thenExpr : elseExpr;
}

public static Long caseFn(Boolean condition, Long thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Double caseFn(Boolean condition, Double thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static String caseFn(Boolean condition, String thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Boolean caseFn(Boolean condition, Boolean thenExpr) {
if (condition == null) {
return null;
}
return condition ? thenExpr : null;
}

public static Long nvl(Long value, Long defaultValue) {
return value == null ? defaultValue : value;
}
Expand Down Expand Up @@ -156,7 +127,7 @@ public ResolvableExpression visitCaseExpr(VtlParser.CaseExprContext ctx) {
thenExprs.add(exprs.get(i + 1));
}
List<ResolvableExpression> whenExpressions = whenExprs.stream()
.map(e -> assertBoolean(exprVisitor.visit(e), e))
.map(exprVisitor::visit)
.collect(Collectors.toList());
List<ResolvableExpression> thenExpressions = thenExprs.stream()
.map(exprVisitor::visit)
Expand Down Expand Up @@ -189,11 +160,13 @@ private ResolvableExpression caseToIfIt(ListIterator<ResolvableExpression> whenE
return elseExpression;
}

ResolvableExpression nextWhen = whenExpr.next();

return genericFunctionsVisitor.invokeFunction("ifThenElse", Java8Helpers.listOf(
whenExpr.next(),
nextWhen,
thenExpr.next(),
caseToIfIt(whenExpr, thenExpr, elseExpression)
), null);
), nextWhen);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import fr.insee.vtl.engine.expressions.CastExpression;
import fr.insee.vtl.engine.expressions.ComponentExpression;
import fr.insee.vtl.engine.expressions.FunctionExpression;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.engine.visitors.expression.ExpressionVisitor;
import fr.insee.vtl.model.*;
import fr.insee.vtl.model.exceptions.VtlScriptException;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.parser.VtlBaseVisitor;
import fr.insee.vtl.parser.VtlParser;
import org.antlr.v4.runtime.Token;
Expand Down Expand Up @@ -169,7 +169,8 @@ private DatasetExpression invokeFunctionOnDataset(String funcName, List<Resolvab
.collect(Collectors.toMap(e -> "arg" + e.hashCode(), e -> e));
if (measureNames.size() != 1) {
throw new VtlRuntimeException(
new InvalidArgumentException("mono-measure datasets don't contain same measures (number or names)", position)
new InvalidArgumentException("Variables in the mono-measure datasets are not named the same: " +
measureNames + " found", position)
);
}
DatasetExpression ds = proc.executeInnerJoin(dsExprs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import fr.insee.vtl.engine.exceptions.FunctionNotFoundException;
import fr.insee.vtl.engine.samples.DatasetSamples;
import fr.insee.vtl.model.utils.Java8Helpers;
import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.utils.Java8Helpers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -92,6 +92,15 @@ public void testCaseExpr() throws ScriptException {
Java8Helpers.mapOf("id", "Franck", "c", 1L)
);
assertThat(((Dataset) res1).getDataStructure().get("c").getType()).isEqualTo(Long.class);
engine.eval("ds1 := ds_1[keep id, long1][rename long1 to bool_var];" +
"ds2 := ds_2[keep id, long1][rename long1 to bool_var]; " +
"res_ds <- case when ds1 < 30 then ds1 else ds2;");
Object res_ds = engine.getContext().getAttribute("res_ds");
assertThat(((Dataset) res_ds).getDataAsMap()).containsExactlyInAnyOrder(
Java8Helpers.mapOf("id", "Hadrien", "bool_var", 10L),
Java8Helpers.mapOf("id", "Nico", "bool_var", 20L),
Java8Helpers.mapOf("id", "Franck", "bool_var", 100L)
);
}

@Test
Expand Down

0 comments on commit 242ad21

Please sign in to comment.