// Convolution.java

package org.opentrafficsim.draw.egtf;

import java.util.Locale;
import java.util.stream.IntStream;

/**
 * Utility class for convolution using fast fourier transformation. This utility is specifically tailored to EGTF and not for
 * general fast fourier purposes.
 * <p>
 * Copyright (c) 2013-2024 Delft University of Technology, PO Box 5, 2600 AA, Delft, the Netherlands. All rights reserved. <br>
 * BSD-style license. See <a href="https://opentrafficsim.org/docs/license.html">OpenTrafficSim License</a>.
 * </p>
 * @author <a href="https://github.com/averbraeck">Alexander Verbraeck</a>
 * @author <a href="https://tudelft.nl/staff/p.knoppers-1">Peter Knoppers</a>
 * @author <a href="https://github.com/wjschakel">Wouter Schakel</a>
 */
public final class Convolution
{

    /** Number of nanoseconds in a millisecond, used to report the benchmark timings of {@code main}. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /** Largest product of kernel side length and data side length that {@code main} will still benchmark. */
    private static final int MAX_SIZE_PRODUCT = 100000;

    /**
     * Private constructor.
     */
    private Convolution()
    {
        //
    }

    /**
     * Program entry point. Compares the direct convolution with the fast fourier based convolution, both on equality of the
     * results and on execution time, for square random matrices of various sizes.
     * @param args String...; the command line arguments (not used)
     */
    public static void main(final String... args)
    {
        int[] size = new int[] {10, 12, 15, 18, 20, 25, 30, 35, 50, 100, 200, 500, 1000};
        for (int kernelSize : size)
        {
            for (int dataSize : size)
            {
                if (kernelSize * dataSize > MAX_SIZE_PRODUCT)
                {
                    continue; // the direct convolution would take too long for this combination
                }
                double[][] a = randomMatrix(kernelSize, 1.0);
                double[][] b = randomMatrix(dataSize, 35.0);
                // nanoTime is monotonic, which makes it the appropriate clock for measuring elapsed time
                long start = System.nanoTime();
                double[][] out1 = conv(a, b);
                long t1 = (System.nanoTime() - start) / NANOS_PER_MILLI;
                start = System.nanoTime();
                double[][] out2 = convolution(a, b);
                long t2 = (System.nanoTime() - start) / NANOS_PER_MILLI;
                for (int k = 0; k < out1.length; k++)
                {
                    for (int l = 0; l < out1[k].length; l++)
                    {
                        if (Math.abs(out1[k][l] - out2[k][l]) > 1e-6)
                        {
                            throw new RuntimeException(
                                    String.format("output unequal: %.16f vs. %.16f", out1[k][l], out2[k][l]));
                        }
                    }
                }
                // gain is the time saved by the fft approach, i.e. positive when fft is faster
                System.out.println(String.format("a = %d, b = %d: tConv = %dms, tFft = %dms, gain = %dms", kernelSize,
                        dataSize, t1, t2, t1 - t2));
            }
        }
    }

    /**
     * Creates a square matrix filled with uniformly distributed random values in the range {@code [0, scale)}.
     * @param size int; number of rows and columns
     * @param scale double; factor applied to each random value
     * @return double[][]; random square matrix
     */
    private static double[][] randomMatrix(final int size, final double scale)
    {
        double[][] matrix = new double[size][size];
        for (int k = 0; k < size; k++)
        {
            for (int l = 0; l < size; l++)
            {
                matrix[k][l] = Math.random() * scale;
            }
        }
        return matrix;
    }

    /**
     * Direct (nested loop) convolution of two matrices. This does not use the fast fourier transform; it serves as the
     * reference implementation against which {@code convolution(...)} is verified and timed in {@code main}.
     * @param a double[][]; the kernel matrix
     * @param b double[][]; the data matrix
     * @return double[][]; convolution of a over b, same size as b
     */
    private static double[][] conv(final double[][] a, final double[][] b)
    {
        double[][] out2 = new double[b.length][b[0].length];
        // the kernel is centered on each output cell, hence the offset of half the kernel size
        int fromRow2 = a.length / 2;
        int fromCol2 = a[0].length / 2;
        for (int i = 0; i < b.length; i++)
        {
            for (int j = 0; j < b[0].length; j++)
            {
                for (int k = 0; k < a.length; k++)
                {
                    for (int l = 0; l < a[0].length; l++)
                    {
                        int m = i - k + fromRow2;
                        int n = j - l + fromCol2;
                        // data outside of 'b' is considered to be zero
                        if (m >= 0 && n >= 0 && m < b.length && n < b[0].length)
                        {
                            out2[i][j] += a[k][l] * b[m][n];
                        }
                    }
                }
            }
        }
        return out2;
    }

    /**
     * Convolution of two matrices using fast fourier transform. Both matrices need to be rectangular and non-empty. The input
     * matrices are not changed by this method.
     * @param a double[][]; the kernel matrix
     * @param b double[][]; the data matrix
     * @return double[][]; convolution of a over b, same size as b
     * @throws IllegalArgumentException if either matrix is empty or not rectangular
     */
    public static double[][] convolution(final double[][] a, final double[][] b)
    {
        checkMatrix(a, "a");
        checkMatrix(b, "b");
        // create zero-padded matrices with dimensions as a power of 2, large enough to hold the full linear convolution
        int i2 = nextPowerOfTwo(a.length + b.length - 1);
        int j2 = nextPowerOfTwo(a[0].length + b[0].length - 1);
        double[][] a2 = zeroPadding(a, i2, j2); // the padded copies are transformed in place, the input stays untouched
        double[][] b2 = zeroPadding(b, i2, j2);
        // fft
        Complex[] a3 = fft2(a2);
        Complex[] b3 = fft2(b2);
        // element-wise complex product (store in a3)
        for (int k = 0; k < i2; k++)
        {
            double[] aRe = a3[k].re;
            double[] aIm = a3[k].im;
            double[] bRe = b3[k].re;
            double[] bIm = b3[k].im;
            for (int m = 0; m < j2; m++)
            {
                double re = aRe[m] * bRe[m] - aIm[m] * bIm[m]; // im depends on re, so need tmp variable
                aIm[m] = aRe[m] * bIm[m] + aIm[m] * bRe[m];
                aRe[m] = re;
            }
        }
        // inverse fft
        ifft2(a3);
        // crop padded zeros (note that the convolution is centered in the resulting matrix, we start at half the size of 'a')
        double[][] out = new double[b.length][b[0].length];
        int fromRow = a.length / 2;
        int fromCol = a[0].length / 2;
        for (int k = 0; k < b.length; k++)
        {
            System.arraycopy(a3[fromRow + k].re, fromCol, out[k], 0, out[k].length);
        }
        return out;
    }

    /**
     * Checks that a matrix has at least one row and one column, and that all rows are equally long.
     * @param matrix double[][]; matrix to check
     * @param name String; name of the matrix for in the exception message
     * @throws IllegalArgumentException if the matrix is empty or not rectangular
     */
    private static void checkMatrix(final double[][] matrix, final String name)
    {
        if (matrix.length == 0 || matrix[0].length == 0)
        {
            throw new IllegalArgumentException("Matrix '" + name + "' is empty.");
        }
        int columns = matrix[0].length;
        for (int k = 1; k < matrix.length; k++)
        {
            if (matrix[k].length != columns)
            {
                throw new IllegalArgumentException("Matrix '" + name + "' is not rectangular: row " + k + " has "
                        + matrix[k].length + " columns, expected " + columns + ".");
            }
        }
    }

    /**
     * Returns the smallest power of 2 that is at least {@code n}. A value that is a power of 2 itself is returned as is, so
     * no padding beyond what the fast fourier transform requires is introduced.
     * @param n int; minimum value, should be positive
     * @return int; smallest power of 2 greater than or equal to {@code n}
     */
    private static int nextPowerOfTwo(final int n)
    {
        return n <= 1 ? 1 : Integer.highestOneBit(n - 1) << 1;
    }

    /**
     * Adds zeros to a matrix to obtain size {@code i x j}. The original matrix is copied to the upper-left corner.
     * @param x double[][]; original matrix
     * @param i int; number of desired rows
     * @param j int; number of desired columns
     * @return double[][]; {@code x} padded with zeros
     */
    private static double[][] zeroPadding(final double[][] x, final int i, final int j)
    {
        double[][] x2 = new double[i][j]; // java initializes all values to zero
        int rows = Math.min(i, x.length);
        for (int k = 0; k < rows; k++)
        {
            System.arraycopy(x[k], 0, x2[k], 0, x[k].length);
        }
        return x2;
    }

    /**
     * Two-dimensional fast fourier transform. The number of rows and the number of columns both need to be a power of 2.
     * @param x double[][]; matrix, this data is affected by the method
     * @return Complex[]; array of complex objects, each representing a row of complex values
     */
    private static Complex[] fft2(final double[][] x)
    {
        Complex[] xComp = new Complex[x.length];
        // create complex objects and perform the row-fft (in place, the rows of 'x' become the real parts)
        for (int i = 0; i < x.length; i++)
        {
            xComp[i] = fft(new Complex(x[i]));
        }
        // perform the column fft
        for (int i = 0; i < x[0].length; i++)
        {
            // gather the column in temporary vectors as the data is stored per row
            double[] re = new double[x.length];
            double[] im = new double[x.length];
            for (int j = 0; j < x.length; j++)
            {
                re[j] = xComp[j].re[i];
                im[j] = xComp[j].im[i];
            }
            fft(new Complex(re, im)); // transforms 're' and 'im' in place
            for (int j = 0; j < x.length; j++)
            {
                xComp[j].re[i] = re[j];
                xComp[j].im[i] = im[j];
            }
        }
        return xComp;
    }

    /**
     * Fast fourier transform using Cooley–Tukey algorithm. This method is based on
     * https://introcs.cs.princeton.edu/java/97data/FFT.java.html. The transform is performed in place. The length of the
     * vector needs to be a power of 2.
     * @param x Complex; vector of complex objects
     * @return Complex; vector after fourier transform (the same object as the input)
     */
    private static Complex fft(final Complex x)
    {
        // bit reversal permutation (this simply rearranges the order in a way that happens to work for the butterfly updates)
        int n = x.re.length;
        int shift = 1 + Integer.numberOfLeadingZeros(n);
        for (int k = 0; k < n; k++)
        {
            int j = Integer.reverse(k) >>> shift;
            if (j > k) // swap each pair only once
            {
                double temp = x.re[j];
                x.re[j] = x.re[k];
                x.re[k] = temp;
                temp = x.im[j];
                x.im[j] = x.im[k];
                x.im[k] = temp;
            }
        }
        // butterfly updates
        for (int l = 2; l <= n; l = l + l)
        {
            int half = l / 2;
            double pil = -2.0 * Math.PI / l;
            for (int k = 0; k < half; k++)
            {
                // twiddle factor, shared by all butterflies at offset k
                double kth = k * pil;
                double wReal = Math.cos(kth);
                double wImag = Math.sin(kth);
                for (int j = 0; j < n / l; j++)
                {
                    int jlk = j * l + k;
                    int jlkl2 = jlk + half;
                    double xReal = x.re[jlkl2];
                    double xImag = x.im[jlkl2];
                    double taoReal = wReal * xReal - wImag * xImag;
                    double taoImag = wReal * xImag + wImag * xReal;
                    x.re[jlkl2] = x.re[jlk] - taoReal;
                    x.im[jlkl2] = x.im[jlk] - taoImag;
                    x.re[jlk] = x.re[jlk] + taoReal;
                    x.im[jlk] = x.im[jlk] + taoImag;
                }
            }
        }
        return x;
    }

    /**
     * Two-dimensional inverse fourier transform. Result is stored in the input objects.
     * @param x Complex[]; array of complex objects, each representing a row of complex values
     */
    private static void ifft2(final Complex[] x)
    {
        // perform the row ifft
        for (int i = 0; i < x.length; i++)
        {
            ifft(x[i]);
        }
        // perform the column ifft
        for (int i = 0; i < x[0].re.length; i++)
        {
            // gather the column in temporary vectors as the data is stored per row
            double[] re = new double[x.length];
            double[] im = new double[x.length];
            int col = i; // effective final
            IntStream.range(0, x.length).forEach(j ->
            {
                re[j] = x[j].re[col];
                im[j] = x[j].im[col];
            });
            ifft(new Complex(re, im)); // transforms 're' and 'im' in place
            IntStream.range(0, x.length).forEach(j ->
            {
                x[j].re[col] = re[j];
                x[j].im[col] = im[j];
            });
        }
    }

    /**
     * Inverse fourier transform. Result is stored in the input object. The inverse is obtained by taking the conjugate,
     * applying the forward transform, taking the conjugate again, and scaling by the vector length.
     * @param x Complex; vector of complex values
     */
    private static void ifft(final Complex x)
    {
        // conjugate
        int n = x.im.length;
        for (int i = 0; i < n; i++)
        {
            x.im[i] = -x.im[i];
        }
        // forward fft
        fft(x);
        // conjugate and scaling
        for (int i = 0; i < n; i++)
        {
            x.im[i] = -x.im[i] / n;
            x.re[i] = x.re[i] / n;
        }
    }

    /**
     * Class that contains a vector of complex values. The real and imaginary parts are stored in two separate arrays, which
     * are not copied on construction; changes to the arrays are thus visible to whoever supplied them.
     * <p>
     * Copyright (c) 2013-2024 Delft University of Technology, PO Box 5, 2600 AA, Delft, the Netherlands. All rights reserved.
     * <br>
     * BSD-style license. See <a href="https://opentrafficsim.org/docs/license.html">OpenTrafficSim License</a>.
     * </p>
     * @author <a href="https://github.com/averbraeck">Alexander Verbraeck</a>
     * @author <a href="https://tudelft.nl/staff/p.knoppers-1">Peter Knoppers</a>
     * @author <a href="https://github.com/wjschakel">Wouter Schakel</a>
     */
    private static final class Complex
    {

        /** Real part. */
        @SuppressWarnings("visibilitymodifier")
        public final double[] re;

        /** Imaginary part. */
        @SuppressWarnings("visibilitymodifier")
        public final double[] im;

        /**
         * Constructor for zero imaginary part.
         * @param x double[]; real part
         */
        Complex(final double[] x)
        {
            this.re = x;
            this.im = new double[x.length];
        }

        /**
         * Constructor.
         * @param re double[]; real part
         * @param im double[]; imaginary part;
         */
        Complex(final double[] re, final double[] im)
        {
            this.re = re;
            this.im = im;
        }

        /** {@inheritDoc} */
        @Override
        public String toString()
        {
            StringBuilder str = new StringBuilder("[");
            String sep = "";
            for (int i = 0; i < this.re.length; i++)
            {
                str.append(sep);
                sep = ", ";
                // print the sign explicitly so that the imaginary part reads as 'a+bi' or 'a-bi'
                if (this.im[i] >= 0)
                {
                    str.append(String.format(Locale.US, "%.2f+%.2fi", this.re[i], this.im[i]));
                }
                else
                {
                    str.append(String.format(Locale.US, "%.2f-%.2fi", this.re[i], -this.im[i]));
                }
            }
            str.append("]");
            return str.toString();
        }

    }
}