-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathSample.java
105 lines (89 loc) · 3.96 KB
/
Sample.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package lphy.base.distribution;
import lphy.base.ParameterNames;
import lphy.core.model.GenerativeDistribution;
import lphy.core.model.RandomVariable;
import lphy.core.model.Value;
import lphy.core.model.VariableUtils;
import lphy.core.model.annotation.GeneratorInfo;
import lphy.core.model.annotation.ParameterInfo;
import lphy.core.simulator.RandomUtils;
import java.util.*;
/**
 * Generative distribution that uniformly samples {@code size} elements from a 1d array,
 * either with replacement (the same element may be drawn several times) or without
 * replacement (each position of the input array is drawn at most once).
 *
 * @param <T> the element type of the array being sampled.
 */
public class Sample<T> implements GenerativeDistribution<T[]> {

    // static final so it is unambiguously a compile-time constant for the @ParameterInfo below
    private static final String REPLACE_PARAM_NAME = "replace";

    private Value<T[]> x;
    private Value<Integer> size;
    private Value<Boolean> replace;
    Random random;

    /**
     * @param x       the 1d array to be sampled; must not be null.
     * @param size    the number of elements to choose; must not be null.
     * @param replace whether to sample with replacement; optional, defaults to {@code false}.
     * @throws IllegalArgumentException if {@code x} or {@code size} is null.
     */
    public Sample(@ParameterInfo(name = ParameterNames.ArrayParamName,
                          description = "1d-array to be sampled.") Value<T[]> x,
                  @ParameterInfo(name = ParameterNames.SizeParamName,
                          description = "the number of elements to choose.") Value<Integer> size,
                  @ParameterInfo(name = REPLACE_PARAM_NAME, description = "If replace is true, " +
                          "the same element can be sampled multiple times, if false (as default), " +
                          "it can only appear once in the result.",
                          optional = true) Value<Boolean> replace) {
        if (x == null) throw new IllegalArgumentException("The array can't be null!");
        if (size == null) throw new IllegalArgumentException("The size can't be null!");
        this.x = x;
        this.size = size;
        // optional parameter: default to sampling without replacement
        if (replace == null)
            replace = new Value<>(null, false);
        this.replace = replace;
        random = RandomUtils.getJavaRandom();
    }

    /**
     * Replaces one of this distribution's parameters by name, validating the NEW value.
     * The "size must not exceed the array length" rule is checked in {@link #sample()}
     * instead, because it only applies without replacement and depends on the order in
     * which parameters are set.
     *
     * @param paramName one of the array, size or replace parameter names.
     * @param value     the new parameter value.
     * @throws IllegalArgumentException if the new array is null/empty, or the new size is null or not positive.
     * @throws RuntimeException         if {@code paramName} is not recognised.
     */
    public void setParam(String paramName, Value value) {
        if (paramName.equals(ParameterNames.ArrayParamName)) {
            // unchecked: the raw Value carries the array type of this distribution's x parameter
            @SuppressWarnings("unchecked")
            T[] arr = value == null ? null : (T[]) value.value();
            if (arr == null || arr.length < 1)
                throw new IllegalArgumentException("Must have at least 1 element in the array! " + Arrays.toString(arr));
            x = value;
        }
        else if (paramName.equals(ParameterNames.SizeParamName)) {
            Integer newSize = value == null ? null : (Integer) value.value();
            if (newSize == null || newSize <= 0)
                throw new IllegalArgumentException("Invalid size : " + newSize);
            size = value;
        }
        else if (paramName.equals(REPLACE_PARAM_NAME)) {
            replace = value;
        }
        else throw new RuntimeException("Unrecognised parameter name: " + paramName);
    }

    /**
     * Draws {@code size} elements uniformly from the array parameter. The array held by
     * the {@code x} parameter is never modified.
     *
     * @return a random variable named "S" holding the sampled elements.
     * @throws IllegalArgumentException if sampling without replacement and {@code size} is
     *                                  negative or larger than the array length.
     */
    @GeneratorInfo(name = "sample", description = "The sample function uniformly sample the subset of " +
            "a given size from an array of the elements either with or without the replacement.")
    public RandomVariable<T[]> sample() {
        List<T> origArr = Arrays.asList(x.value());
        int n = origArr.size();
        int s = size.value();
        boolean withReplacement = isReplace();

        // use List to handle generic
        List<T> sampled;
        if (withReplacement) {
            sampled = new ArrayList<>(Math.max(s, 0));
            for (int i = 0; i < s; i++)
                sampled.add(origArr.get(random.nextInt(n)));
        } else { // no replacement
            if (s < 0 || s > n)
                throw new IllegalArgumentException("Cannot sample " + s +
                        " elements without replacement from an array of " + n + " elements!");
            // Arrays.asList is a write-through view of x's array, so shuffle a copy;
            // shuffling the view directly would reorder the parameter's value as a side effect.
            List<T> pool = new ArrayList<>(origArr);
            Collections.shuffle(pool, random);
            sampled = pool.stream().limit(s).toList();
        }

        System.out.println("Sample " + sampled.size() + " elements from the vector of " + n +
                (withReplacement ? " with":" without") + " the replacement.");
        return VariableUtils.createRandomVariable( "S", sampled, this);
    }

    // A missing replace parameter, or one whose value is null, means the default: no replacement.
    private boolean isReplace() {
        return replace != null && Boolean.TRUE.equals(replace.value());
    }

    /**
     * @return the parameters keyed by name, in sorted name order; the optional replace
     *         parameter is omitted when it is null.
     */
    public Map<String, Value> getParams() {
        // plain TreeMap rather than double-brace initialisation, which would create an
        // anonymous subclass holding a hidden reference to this Sample
        Map<String, Value> params = new TreeMap<>();
        params.put(ParameterNames.ArrayParamName, x);
        params.put(ParameterNames.SizeParamName, size);
        if (replace != null) params.put(REPLACE_PARAM_NAME, replace); // optional
        return params;
    }

    /** Sets the array parameter without validation. */
    public void setX(Value<T[]> x) {
        this.x = x;
    }

    /** Sets the size parameter without validation. */
    public void setSize(Value<Integer> size) {
        this.size = size;
    }

    /** Sets the replace parameter; null is treated as "without replacement" when sampling. */
    public void setReplace(Value<Boolean> replace) {
        this.replace = replace;
    }
}