1 package org.opentrafficsim.core.egtf;
2
3 import java.util.Locale;
4 import java.util.stream.IntStream;
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * Two-dimensional convolution of {@code double[][]} grids. The public {@link #convolution(double[][], double[][])} method
 * computes the result through a two-dimensional radix-2 fast Fourier transform (FFT) on zero-padded copies of the inputs.
 * A direct summation ({@code conv}) is kept as a reference implementation for the benchmark in {@link #main(String...)}.
 */
public final class Convolution
{

    /**
     * Private constructor; this class only contains static methods.
     */
    private Convolution()
    {
        // no instances
    }

    /**
     * Benchmarks the direct convolution against the FFT-based convolution for combinations of square random matrices,
     * verifies that both produce the same output (within 1e-6), and prints the timings.
     * @param args String...; not used
     * @throws IllegalStateException when the outputs of both methods differ by more than 1e-6
     */
    public static void main(final String... args)
    {
        int[] sizes = new int[] {10, 12, 15, 18, 20, 25, 30, 35, 50, 100, 200, 500, 1000};
        for (int sizeA : sizes)
        {
            for (int sizeB : sizes)
            {
                // skip combinations for which the direct convolution would take too long
                if (sizeA * sizeB > 100000)
                {
                    continue;
                }
                double[][] a = randomMatrix(sizeA, 1.0);
                double[][] b = randomMatrix(sizeB, 35.0);

                // System.nanoTime() is the monotonic clock intended for elapsed-time measurements
                long start = System.nanoTime();
                double[][] out1 = conv(a, b);
                long t1 = (System.nanoTime() - start) / 1_000_000L;
                start = System.nanoTime();
                double[][] out2 = convolution(a, b);
                long t2 = (System.nanoTime() - start) / 1_000_000L;

                for (int k = 0; k < out1.length; k++)
                {
                    for (int l = 0; l < out1[k].length; l++)
                    {
                        if (Math.abs(out1[k][l] - out2[k][l]) > 1e-6)
                        {
                            throw new IllegalStateException(
                                    String.format("output unequal: %.16f vs. %.16f", out1[k][l], out2[k][l]));
                        }
                    }
                }
                // note: the printed gain is tFft - tConv, i.e. it is negative when the FFT method is faster
                System.out.println(String.format("a = %d, b = %d: tConv = %dms, tFft = %dms, gain = %dms", sizeA, sizeB,
                        t1, t2, t2 - t1));
            }
        }
    }

    /**
     * Creates a square matrix filled with values {@code Math.random() * scale}.
     * @param n int; number of rows and columns
     * @param scale double; factor applied to each random value in [0, 1)
     * @return double[][]; new n x n matrix
     */
    private static double[][] randomMatrix(final int n, final double scale)
    {
        double[][] m = new double[n][n];
        for (int k = 0; k < n; k++)
        {
            for (int l = 0; l < n; l++)
            {
                m[k][l] = Math.random() * scale;
            }
        }
        return m;
    }

    /**
     * Direct (brute-force) convolution, used as a reference for {@link #convolution(double[][], double[][])}. The output
     * has the dimensions of {@code b}. Element {@code [i][j]} is the sum over all {@code k, l} of
     * {@code a[k][l] * b[i - k + a.length / 2][j - l + a[0].length / 2]}, where indices outside {@code b} contribute
     * nothing.
     * @param a double[][]; kernel (rows are expected to have equal length)
     * @param b double[][]; data (rows are expected to have equal length)
     * @return double[][]; convolution result with the dimensions of {@code b}
     */
    private static double[][] conv(final double[][] a, final double[][] b)
    {
        double[][] out2 = new double[b.length][b[0].length];
        // center the kernel on each output element
        int fromRow2 = a.length / 2;
        int fromCol2 = a[0].length / 2;
        for (int i = 0; i < b.length; i++)
        {
            for (int j = 0; j < b[0].length; j++)
            {
                for (int k = 0; k < a.length; k++)
                {
                    for (int l = 0; l < a[0].length; l++)
                    {
                        int m = i - k + fromRow2;
                        int n = j - l + fromCol2;
                        if (m >= 0 && n >= 0 && m < b.length && n < b[0].length)
                        {
                            out2[i][j] += a[k][l] * b[m][n];
                        }
                    }
                }
            }
        }
        return out2;
    }

    /**
     * Convolves {@code b} with kernel {@code a} using a two-dimensional FFT. The output has the same dimensions as
     * {@code b}. Element {@code [i][j]} equals the sum over all {@code k, l} of
     * {@code a[k][l] * b[i - k + a.length / 2][j - l + a[0].length / 2]}, where indices outside {@code b} contribute
     * nothing. The input arrays are not modified; all transforms operate on zero-padded copies.
     * @param a double[][]; kernel (rows are expected to have equal length)
     * @param b double[][]; data to convolve (rows are expected to have equal length)
     * @return double[][]; convolution result with the dimensions of {@code b}
     */
    public static double[][] convolution(final double[][] a, final double[][] b)
    {
        // A linear convolution needs at least (na + nb - 1) samples per dimension to avoid circular wrap-around, and the
        // radix-2 FFT additionally needs a power-of-two length.
        int rows = nextPowerOfTwo(a.length + b.length - 1);
        int cols = nextPowerOfTwo(a[0].length + b[0].length - 1);
        Complex[] aHat = fft2(zeroPadding(a, rows, cols));
        Complex[] bHat = fft2(zeroPadding(b, rows, cols));

        // point-wise complex product in the frequency domain, stored in aHat
        for (int k = 0; k < rows; k++)
        {
            for (int m = 0; m < cols; m++)
            {
                double re = aHat[k].re[m] * bHat[k].re[m] - aHat[k].im[m] * bHat[k].im[m];
                aHat[k].im[m] = aHat[k].re[m] * bHat[k].im[m] + aHat[k].im[m] * bHat[k].re[m];
                aHat[k].re[m] = re;
            }
        }

        ifft2(aHat);

        // extract the part with the dimensions of b, offset by half the kernel size (centered kernel)
        double[][] out = new double[b.length][b[0].length];
        int fromRow = a.length / 2;
        int fromCol = a[0].length / 2;
        for (int k = 0; k < b.length; k++)
        {
            System.arraycopy(aHat[fromRow + k].re, fromCol, out[k], 0, out[k].length);
        }
        return out;
    }

    /**
     * Returns the smallest power of two that is greater than or equal to {@code n}, with a minimum of 1.
     * @param n int; required length
     * @return int; smallest power of two {@code >= n} (1 for {@code n <= 1})
     */
    private static int nextPowerOfTwo(final int n)
    {
        return n <= 1 ? 1 : Integer.highestOneBit(n - 1) << 1;
    }

    /**
     * Copies {@code x} into the top-left corner of a new {@code i x j} matrix of zeros.
     * @param x double[][]; matrix to pad (only its first {@code i} rows are copied)
     * @param i int; number of rows of the result
     * @param j int; number of columns of the result, must be at least the length of every copied row
     * @return double[][]; new zero-padded matrix
     */
    private static double[][] zeroPadding(final double[][] x, final int i, final int j)
    {
        double[][] x2 = new double[i][j];
        int copyRows = Math.min(x.length, i);
        for (int k = 0; k < copyRows; k++)
        {
            System.arraycopy(x[k], 0, x2[k], 0, x[k].length);
        }
        return x2;
    }

    /**
     * Two-dimensional FFT: a 1D FFT over every row, followed by a 1D FFT over every column. The row arrays of {@code x}
     * are used directly as the real parts of the result, so {@code x} is overwritten.
     * @param x double[][]; real input, both dimensions must be powers of two
     * @return Complex[]; one complex row per row of {@code x}, holding the 2D spectrum
     */
    private static Complex[] fft2(final double[][] x)
    {
        Complex[] xComp = new Complex[x.length];

        for (int i = 0; i < x.length; i++)
        {
            xComp[i] = fft(new Complex(x[i]));
        }

        for (int i = 0; i < x[0].length; i++)
        {
            // gather the column, transform it in place, and scatter it back
            double[] re = new double[x.length];
            double[] im = new double[x.length];
            for (int j = 0; j < x.length; j++)
            {
                re[j] = xComp[j].re[i];
                im[j] = xComp[j].im[i];
            }
            fft(new Complex(re, im));
            for (int j = 0; j < x.length; j++)
            {
                xComp[j].re[i] = re[j];
                xComp[j].im[i] = im[j];
            }
        }
        return xComp;
    }

    /**
     * In-place iterative radix-2 Cooley-Tukey FFT: bit-reversal permutation followed by butterfly stages.
     * @param x Complex; data to transform in place, its length must be a power of two
     * @return Complex; the same object {@code x}, now holding the transform
     */
    private static Complex fft(final Complex x)
    {
        int n = x.re.length;
        // for n = 2^p, reversing all 32 bits and shifting right by (32 - p) reverses the lowest p bits
        int shift = 1 + Integer.numberOfLeadingZeros(n);
        for (int k = 0; k < n; k++)
        {
            int j = Integer.reverse(k) >>> shift;
            if (j > k)
            {
                double temp = x.re[j];
                x.re[j] = x.re[k];
                x.re[k] = temp;
                temp = x.im[j];
                x.im[j] = x.im[k];
                x.im[k] = temp;
            }
        }

        for (int l = 2; l <= n; l = l + l)
        {
            double pil = -2.0 * Math.PI / l;
            for (int k = 0; k < l / 2; k++)
            {
                // twiddle factor exp(-2*pi*i*k/l)
                double kth = k * pil;
                double wReal = Math.cos(kth);
                double wImag = Math.sin(kth);
                for (int j = 0; j < n / l; j++)
                {
                    int jlk = j * l + k;
                    int jlkl2 = jlk + l / 2;
                    double xReal = x.re[jlkl2];
                    double xImag = x.im[jlkl2];
                    double taoReal = wReal * xReal - wImag * xImag;
                    double taoImag = wReal * xImag + wImag * xReal;
                    x.re[jlkl2] = x.re[jlk] - taoReal;
                    x.im[jlkl2] = x.im[jlk] - taoImag;
                    x.re[jlk] = x.re[jlk] + taoReal;
                    x.im[jlk] = x.im[jlk] + taoImag;
                }
            }
        }
        return x;
    }

    /**
     * In-place two-dimensional inverse FFT: a 1D inverse FFT over every row, followed by one over every column.
     * @param x Complex[]; rows of complex data, transformed in place (dimensions must be powers of two)
     */
    private static void ifft2(final Complex[] x)
    {
        for (int i = 0; i < x.length; i++)
        {
            ifft(x[i]);
        }

        for (int i = 0; i < x[0].re.length; i++)
        {
            double[] re = new double[x.length];
            double[] im = new double[x.length];
            int col = i;
            IntStream.range(0, x.length).forEach(j ->
            {
                re[j] = x[j].re[col];
                im[j] = x[j].im[col];
            });
            ifft(new Complex(re, im));
            IntStream.range(0, x.length).forEach(j ->
            {
                x[j].re[col] = re[j];
                x[j].im[col] = im[j];
            });
        }
    }

    /**
     * In-place 1D inverse FFT, computed as conj(fft(conj(x))) / n.
     * @param x Complex; data to transform in place, its length must be a power of two
     */
    private static void ifft(final Complex x)
    {
        int n = x.im.length;
        for (int i = 0; i < n; i++)
        {
            x.im[i] = -x.im[i];
        }

        fft(x);

        for (int i = 0; i < n; i++)
        {
            x.im[i] = -x.im[i] / n;
            x.re[i] = x.re[i] / n;
        }
    }

    /**
     * Mutable vector of complex numbers, stored as separate real and imaginary arrays.
     */
    private static class Complex
    {

        /** Real parts. */
        @SuppressWarnings("visibilitymodifier")
        public final double[] re;

        /** Imaginary parts. */
        @SuppressWarnings("visibilitymodifier")
        public final double[] im;

        /**
         * Constructs a complex vector with the given real parts (array is used, not copied) and zero imaginary parts.
         * @param x double[]; real parts
         */
        Complex(final double[] x)
        {
            this.re = x;
            this.im = new double[x.length];
        }

        /**
         * Constructs a complex vector wrapping the given arrays (not copied).
         * @param re double[]; real parts
         * @param im double[]; imaginary parts
         */
        Complex(final double[] re, final double[] im)
        {
            this.re = re;
            this.im = im;
        }

        /** {@inheritDoc} */
        @Override
        public String toString()
        {
            StringBuilder str = new StringBuilder("[");
            String sep = "";
            for (int i = 0; i < this.re.length; i++)
            {
                str.append(sep);
                sep = ", ";
                // Math.abs also normalizes -0.0, which would otherwise print as "+-0.00i"
                String sign = this.im[i] < 0 ? "-" : "+";
                str.append(String.format(Locale.US, "%.2f%s%.2fi", this.re[i], sign, Math.abs(this.im[i])));
            }
            str.append("]");
            return str.toString();
        }

    }
}