View Javadoc
1   package org.opentrafficsim.core.egtf;
2   
3   import java.util.Locale;
4   import java.util.stream.IntStream;
5   
/**
 * Utility class for convolution using the fast Fourier transform (FFT). This utility is specifically tailored to EGTF and is
 * not intended for general-purpose Fourier analysis.
 * <p>
 * Copyright (c) 2013-2019 Delft University of Technology, PO Box 5, 2600 AA, Delft, the Netherlands. All rights reserved. <br>
 * BSD-style license. See <a href="http://opentrafficsim.org/node/13">OpenTrafficSim License</a>.
 * <p>
 * @version $Revision$, $LastChangedDate$, by $Author$, initial version 31 okt. 2018 <br>
 * @author <a href="http://www.tbm.tudelft.nl/averbraeck">Alexander Verbraeck</a>
 * @author <a href="http://www.tudelft.nl/pknoppers">Peter Knoppers</a>
 * @author <a href="http://www.transport.citg.tudelft.nl">Wouter Schakel</a>
 */
public final class Convolution
{

    /**
     * Private constructor.
     */
    private Convolution()
    {
        //
    }

    /**
     * Benchmark and self-check. Convolves random square matrices of various sizes with both the direct method and the FFT
     * method, verifies that both results are equal (within 1e-6), and prints the computation time of each method. The gain is
     * the time saved by the FFT method, which is positive when the FFT method is faster.
     * @param args String...; not used
     * @throws IllegalStateException when the two methods produce different output
     */
    public static void main(final String... args)
    {
        int[] size = new int[] { 10, 12, 15, 18, 20, 25, 30, 35, 50, 100, 200, 500, 1000 };
        for (int kernelSize : size)
        {
            for (int dataSize : size)
            {
                if (kernelSize * dataSize > 100000)
                {
                    continue; // the direct method takes too long for these combinations
                }
                double[][] a = randomMatrix(kernelSize, 1.0);
                double[][] b = randomMatrix(dataSize, 35.0);
                long start = System.nanoTime();
                double[][] out1 = conv(a, b);
                long tConv = (System.nanoTime() - start) / 1_000_000L;
                start = System.nanoTime();
                double[][] out2 = convolution(a, b);
                long tFft = (System.nanoTime() - start) / 1_000_000L;
                for (int k = 0; k < dataSize; k++)
                {
                    for (int l = 0; l < dataSize; l++)
                    {
                        if (Math.abs(out1[k][l] - out2[k][l]) > 1e-6)
                        {
                            throw new IllegalStateException(
                                    String.format("output unequal: %.16f vs. %.16f", out1[k][l], out2[k][l]));
                        }
                    }
                }
                System.out.println(String.format("a = %d, b = %d: tConv = %dms, tFft = %dms, gain = %dms", kernelSize,
                        dataSize, tConv, tFft, tConv - tFft));
            }
        }
    }

    /**
     * Creates a square matrix filled with random values {@code Math.random() * scale}.
     * @param n int; number of rows and columns
     * @param scale double; factor applied to each random value
     * @return double[][]; random {@code n x n} matrix
     */
    private static double[][] randomMatrix(final int n, final double scale)
    {
        double[][] x = new double[n][n];
        for (int k = 0; k < n; k++)
        {
            for (int l = 0; l < n; l++)
            {
                x[k][l] = Math.random() * scale;
            }
        }
        return x;
    }

    /**
     * Direct (brute-force) convolution of two matrices, used as a reference for the FFT method. Entries outside of {@code b}
     * are treated as zero, and the kernel is centered using the offset {@code (a.length / 2, a[0].length / 2)}.
     * @param a double[][]; the kernel matrix
     * @param b double[][]; the data matrix
     * @return double[][]; convolution of a over b, same size as b
     */
    private static double[][] conv(final double[][] a, final double[][] b)
    {
        double[][] out = new double[b.length][b[0].length];
        int fromRow = a.length / 2;
        int fromCol = a[0].length / 2;
        for (int i = 0; i < b.length; i++)
        {
            for (int j = 0; j < b[0].length; j++)
            {
                for (int k = 0; k < a.length; k++)
                {
                    int m = i - k + fromRow;
                    if (m < 0 || m >= b.length)
                    {
                        continue; // the whole kernel row falls outside the data
                    }
                    for (int l = 0; l < a[0].length; l++)
                    {
                        int n = j - l + fromCol;
                        if (n >= 0 && n < b[0].length)
                        {
                            out[i][j] += a[k][l] * b[m][n];
                        }
                    }
                }
            }
        }
        return out;
    }

    /**
     * Convolution of two matrices using the fast Fourier transform. Element {@code (i, j)} of the result is the sum over all
     * {@code (k, l)} of {@code a[k][l] * b[i - k + a.length / 2][j - l + a[0].length / 2]}, where entries outside of {@code b}
     * are zero. That is, the result is the part of the full linear convolution with the same size as {@code b}, with the
     * kernel centered. Both matrices are assumed rectangular; column counts are taken from their first row. The input matrices
     * are not modified.
     * @param a double[][]; the kernel matrix
     * @param b double[][]; the data matrix
     * @return double[][]; convolution of a over b, same size as b (all zeros when {@code a} is empty)
     */
    public static double[][] convolution(final double[][] a, final double[][] b)
    {
        if (b.length == 0)
        {
            return new double[0][0];
        }
        if (a.length == 0 || a[0].length == 0 || b[0].length == 0)
        {
            // empty kernel (or data without columns): every output value is an empty sum
            return new double[b.length][b[0].length];
        }
        // size of the full linear convolution, rounded up to a power of 2 as required by the radix-2 FFT
        int rows = nextPowerOfTwo(a.length + b.length - 1);
        int cols = nextPowerOfTwo(a[0].length + b[0].length - 1);
        // zero-padded copies; the FFT works in place, so copying also keeps the input matrices intact
        Complex[] aHat = fft2(zeroPadding(a, rows, cols));
        Complex[] bHat = fft2(zeroPadding(b, rows, cols));
        // element-wise complex product (stored in aHat): convolution in space is multiplication in frequency
        for (int k = 0; k < rows; k++)
        {
            double[] aRe = aHat[k].re;
            double[] aIm = aHat[k].im;
            double[] bRe = bHat[k].re;
            double[] bIm = bHat[k].im;
            for (int m = 0; m < cols; m++)
            {
                double re = aRe[m] * bRe[m] - aIm[m] * bIm[m]; // im needs the original re, so use a tmp variable
                aIm[m] = aRe[m] * bIm[m] + aIm[m] * bRe[m];
                aRe[m] = re;
            }
        }
        // inverse fft
        ifft2(aHat);
        // crop padded zeros; the part with the size of 'b' starts at half the size of 'a'
        double[][] out = new double[b.length][b[0].length];
        int fromRow = a.length / 2;
        int fromCol = a[0].length / 2;
        for (int k = 0; k < b.length; k++)
        {
            System.arraycopy(aHat[fromRow + k].re, fromCol, out[k], 0, out[k].length);
        }
        return out;
    }

    /**
     * Returns the smallest power of 2 that is equal to or larger than {@code n}.
     * @param n int; minimum value
     * @return int; smallest power of 2 that is {@code >= n}, or 1 if {@code n <= 1}
     */
    private static int nextPowerOfTwo(final int n)
    {
        return n <= 1 ? 1 : Integer.highestOneBit(n - 1) << 1;
    }

    /**
     * Adds zeros to a matrix to obtain size {@code i x j}. The original matrix is copied, not modified.
     * @param x double[][]; original matrix, with at most {@code j} columns in each row
     * @param i int; number of desired rows
     * @param j int; number of desired columns
     * @return double[][]; {@code x} padded with zeros
     */
    private static double[][] zeroPadding(final double[][] x, final int i, final int j)
    {
        double[][] x2 = new double[i][j];
        int copyRows = Math.min(i, x.length);
        for (int k = 0; k < copyRows; k++)
        {
            System.arraycopy(x[k], 0, x2[k], 0, x[k].length);
        }
        return x2;
    }

    /**
     * Two-dimensional fast Fourier transform. Both dimensions must be a power of 2.
     * @param x double[][]; matrix, this data is affected by the method (the rows are used as the real parts)
     * @return Complex[]; array of complex objects, each representing a row of complex values
     */
    private static Complex[] fft2(final double[][] x)
    {
        Complex[] rows = new Complex[x.length];
        // create complex objects and perform the row fft
        for (int i = 0; i < x.length; i++)
        {
            rows[i] = fft(new Complex(x[i]));
        }
        // perform the column fft
        transformColumns(rows, false);
        return rows;
    }

    /**
     * Performs the (inverse) Fourier transform on each column of a matrix of complex values. The result is stored in the
     * input objects.
     * @param rows Complex[]; array of complex objects, each representing a row of complex values of equal length
     * @param inverse boolean; whether to perform the inverse transform
     */
    private static void transformColumns(final Complex[] rows, final boolean inverse)
    {
        int nRows = rows.length;
        int nCols = rows[0].re.length;
        // buffers are reused for each column; the transforms work in place on these arrays
        double[] re = new double[nRows];
        double[] im = new double[nRows];
        Complex column = new Complex(re, im);
        for (int col = 0; col < nCols; col++)
        {
            for (int row = 0; row < nRows; row++)
            {
                re[row] = rows[row].re[col];
                im[row] = rows[row].im[col];
            }
            if (inverse)
            {
                ifft(column);
            }
            else
            {
                fft(column);
            }
            for (int row = 0; row < nRows; row++)
            {
                rows[row].re[col] = re[row];
                rows[row].im[col] = im[row];
            }
        }
    }

    /**
     * Fast Fourier transform using the iterative radix-2 Cooley–Tukey algorithm, performed in place. This method is based on
     * https://introcs.cs.princeton.edu/java/97data/FFT.java.html.
     * @param x Complex; vector of complex values, its length must be a power of 2; the values are replaced by the transform
     * @return Complex; the same object {@code x}, after the Fourier transform
     * @throws IllegalArgumentException when the length is not a power of 2
     */
    private static Complex fft(final Complex x)
    {
        int n = x.re.length;
        if (n <= 1)
        {
            return x; // the transform of a single value (or no value) is the identity
        }
        if ((n & (n - 1)) != 0)
        {
            throw new IllegalArgumentException("FFT length must be a power of 2, got " + n);
        }
        // bit reversal permutation (this rearranges the order such that the butterfly updates can work in place)
        int shift = 1 + Integer.numberOfLeadingZeros(n);
        for (int k = 0; k < n; k++)
        {
            int j = Integer.reverse(k) >>> shift;
            if (j > k)
            {
                double temp = x.re[j];
                x.re[j] = x.re[k];
                x.re[k] = temp;
                temp = x.im[j];
                x.im[j] = x.im[k];
                x.im[k] = temp;
            }
        }
        // butterfly updates
        for (int l = 2; l <= n; l = l + l)
        {
            double pil = -2.0 * Math.PI / l;
            for (int k = 0; k < l / 2; k++)
            {
                double kth = k * pil;
                double wReal = Math.cos(kth);
                double wImag = Math.sin(kth);
                for (int j = 0; j < n / l; j++)
                {
                    int jlk = j * l + k;
                    int jlkl2 = jlk + l / 2;
                    double xReal = x.re[jlkl2];
                    double xImag = x.im[jlkl2];
                    double taoReal = wReal * xReal - wImag * xImag;
                    double taoImag = wReal * xImag + wImag * xReal;
                    x.re[jlkl2] = x.re[jlk] - taoReal;
                    x.im[jlkl2] = x.im[jlk] - taoImag;
                    x.re[jlk] = x.re[jlk] + taoReal;
                    x.im[jlk] = x.im[jlk] + taoImag;
                }
            }
        }
        return x;
    }

    /**
     * Two-dimensional inverse Fourier transform. Result is stored in the input objects.
     * @param x Complex[]; array of complex objects, each representing a row of complex values
     */
    private static void ifft2(final Complex[] x)
    {
        // perform the row ifft
        for (Complex row : x)
        {
            ifft(row);
        }
        // perform the column ifft
        transformColumns(x, true);
    }

    /**
     * Inverse Fourier transform, computed as the conjugate of the forward transform of the conjugate, scaled by {@code 1/n}.
     * Result is stored in the input object.
     * @param x Complex; vector of complex values, its length must be a power of 2
     */
    private static void ifft(final Complex x)
    {
        // conjugate
        int n = x.im.length;
        for (int i = 0; i < n; i++)
        {
            x.im[i] = -x.im[i];
        }
        // forward fft
        fft(x);
        // conjugate and scaling
        for (int i = 0; i < n; i++)
        {
            x.im[i] = -x.im[i] / n;
            x.re[i] = x.re[i] / n;
        }
    }

    /**
     * Vector of complex values, stored as separate arrays of real and imaginary parts. The arrays given to the constructors are
     * used directly (not copied), so changes to the fields affect those arrays.
     * <p>
     * Copyright (c) 2013-2019 Delft University of Technology, PO Box 5, 2600 AA, Delft, the Netherlands. All rights reserved.
     * <br>
     * BSD-style license. See <a href="http://opentrafficsim.org/node/13">OpenTrafficSim License</a>.
     * <p>
     * @version $Revision$, $LastChangedDate$, by $Author$, initial version 31 okt. 2018 <br>
     * @author <a href="http://www.tbm.tudelft.nl/averbraeck">Alexander Verbraeck</a>
     * @author <a href="http://www.tudelft.nl/pknoppers">Peter Knoppers</a>
     * @author <a href="http://www.transport.citg.tudelft.nl">Wouter Schakel</a>
     */
    private static class Complex
    {

        /** Real part. */
        @SuppressWarnings("visibilitymodifier")
        public final double[] re;

        /** Imaginary part. */
        @SuppressWarnings("visibilitymodifier")
        public final double[] im;

        /**
         * Constructor for zero imaginary part.
         * @param x double[]; real part, used directly (not copied)
         */
        Complex(final double[] x)
        {
            this.re = x;
            this.im = new double[x.length];
        }

        /**
         * Constructor.
         * @param re double[]; real part, used directly (not copied)
         * @param im double[]; imaginary part, used directly (not copied), of the same length as {@code re}
         */
        Complex(final double[] re, final double[] im)
        {
            this.re = re;
            this.im = im;
        }

        /** {@inheritDoc} */
        @Override
        public String toString()
        {
            StringBuilder str = new StringBuilder("[");
            String sep = "";
            for (int i = 0; i < this.re.length; i++)
            {
                str.append(sep);
                sep = ", ";
                if (this.im[i] >= 0)
                {
                    str.append(String.format(Locale.US, "%.2f+%.2fi", this.re[i], this.im[i]));
                }
                else
                {
                    str.append(String.format(Locale.US, "%.2f-%.2fi", this.re[i], -this.im[i]));
                }
            }
            str.append("]");
            return str.toString();
        }

    }
}