1 package org.opentrafficsim.core.distributions;
2
3 import java.util.ArrayList;
4 import java.util.List;
5 import java.util.Objects;
6 import java.util.function.Supplier;
7
8 import org.djutils.exceptions.Throw;
9
10 import nl.tudelft.simulation.jstats.streams.StreamInterface;
11
12
13
14
15
16
17
18
19
20
21 public class ObjectDistribution<O> implements Supplier<O>
22 {
23
24 private final List<FrequencyAndObject<O>> objects = new ArrayList<>();
25
26
27 private double cumulativeTotal;
28
29
30 private final StreamInterface stream;
31
32
33
34
35
36
37
38
39 public ObjectDistribution(final List<FrequencyAndObject<O>> objects, final StreamInterface stream)
40 {
41 this(stream);
42 Throw.whenNull(objects, "objects");
43
44
45 this.objects.addAll(objects);
46 fixProbabilities();
47 }
48
49
50
51
52
53
54 public ObjectDistribution(final StreamInterface stream)
55 {
56 Throw.whenNull(stream, "stream");
57 this.stream = stream;
58 }
59
60
61
62
63 private void fixProbabilities()
64 {
65 if (0 == this.objects.size())
66 {
67 return;
68 }
69 this.cumulativeTotal = 0;
70 for (FrequencyAndObject<O> object : this.objects)
71 {
72 double frequency = object.frequency();
73 this.cumulativeTotal += frequency;
74 }
75 }
76
77 @Override
78 public O get()
79 {
80 Throw.when(0 == this.objects.size(), IllegalStateException.class, "Cannot draw from empty collection");
81 Throw.when(0 == this.cumulativeTotal, IllegalStateException.class, "Sum of frequencies or probabilities must be > 0");
82
83 double randomValue = this.stream.nextDouble() * this.cumulativeTotal;
84 for (FrequencyAndObject<O> fAndO : this.objects)
85 {
86 double frequency = fAndO.frequency();
87 if (frequency >= randomValue)
88 {
89 return fAndO.object();
90 }
91 randomValue -= frequency;
92 }
93
94 FrequencyAndObject<O> useThisOne = this.objects.get(0);
95 for (FrequencyAndObject<O> fAndO : this.objects)
96 {
97 if (fAndO.frequency() > 0)
98 {
99 useThisOne = fAndO;
100 break;
101 }
102 }
103 return useThisOne.object();
104 }
105
106
107
108
109
110
111
112 public ObjectDistribution<O> add(final FrequencyAndObject<O> object)
113 {
114 Throw.whenNull(object, "object");
115 return add(this.objects.size(), object);
116 }
117
118
119
120
121
122
123
124
125
126 public ObjectDistribution<O> add(final int index, final FrequencyAndObject<O> object)
127 {
128 Throw.whenNull(object, "object");
129 this.objects.add(index, object);
130 fixProbabilities();
131 return this;
132 }
133
134
135
136
137
138
139
140 public ObjectDistribution<O> remove(final int index)
141 {
142 this.objects.remove(index);
143 fixProbabilities();
144 return this;
145 }
146
147
148
149
150
151
152
153
154
155 public ObjectDistribution<O> set(final int index, final FrequencyAndObject<O> object)
156 {
157 Throw.whenNull(object, "object");
158 this.objects.set(index, object);
159 fixProbabilities();
160 return this;
161 }
162
163
164
165
166
167
168
169
170
171 public ObjectDistribution<O> modifyFrequency(final int index, final double frequency)
172 {
173 Throw.when(index < 0 || index >= this.size(), IndexOutOfBoundsException.class, "Index %s out of range (0..%d)", index,
174 this.size() - 1);
175 return set(index, new FrequencyAndObject<O>(frequency, this.objects.get(index).object()));
176 }
177
178
179
180
181
182 public ObjectDistribution<O> clear()
183 {
184 this.objects.clear();
185 return this;
186 }
187
188
189
190
191
192
193
194 public FrequencyAndObject<O> get(final int index)
195 {
196 return this.objects.get(index);
197 }
198
199
200
201
202
203 public int size()
204 {
205 return this.objects.size();
206 }
207
208 @Override
209 public int hashCode()
210 {
211 return Objects.hash(this.cumulativeTotal, this.objects, this.stream);
212 }
213
214 @Override
215 @SuppressWarnings("checkstyle:needbraces")
216 public boolean equals(final Object obj)
217 {
218 if (this == obj)
219 return true;
220 if (obj == null)
221 return false;
222 if (getClass() != obj.getClass())
223 return false;
224 ObjectDistribution<?> other = (ObjectDistribution<?>) obj;
225 return Double.doubleToLongBits(this.cumulativeTotal) == Double.doubleToLongBits(other.cumulativeTotal)
226 && Objects.equals(this.objects, other.objects) && Objects.equals(this.stream, other.stream);
227 }
228
229 @Override
230 public String toString()
231 {
232 StringBuilder result = new StringBuilder();
233 result.append("Distribution [");
234 String separator = "";
235 for (FrequencyAndObject<O> fAndO : this.objects)
236 {
237 result.append(separator + fAndO.frequency() + "->" + fAndO.object());
238 separator = ", ";
239 }
240 result.append(']');
241 return result.toString();
242 }
243
244 }