Backpropagation (BP) of Conv1D

%pip install jax==0.2.13 jaxlib==0.1.66
from jax import grad
from jax.numpy import array, convolve, inner
a = array([1.1, 2., -3, 2.5, 7.])
b = array([3., 2., -2.2])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Inner product

A frequently used operation is the inner product \(\odot\). Let \(f(x, y)=x\odot y\); then \(\nabla_xf=y\).

c = array([4.5, 1., 5.])
grad(inner)(b, c)  # returns the value of c
DeviceArray([4.5, 1. , 5. ], dtype=float32)
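
By symmetry, the gradient with respect to the second argument should be the first argument. A minimal sketch of that check, using the `argnums` parameter of `grad` to pick which argument is differentiated (the names `grad_b` and `grad_c` are only illustrative):

```python
import jax.numpy as jnp
from jax import grad

b = jnp.array([3., 2., -2.2])
c = jnp.array([4.5, 1., 5.])

# grad differentiates with respect to the first argument by default,
# so this is the gradient with respect to b, which equals c.
grad_b = grad(jnp.inner)(b, c)

# argnums=1 selects the second argument instead; the gradient is then b.
grad_c = grad(jnp.inner, argnums=1)(b, c)
```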

Conv1D

Another common case is convolution, whose vector-Jacobian product is used heavily in the backpropagation analysis of GDBP.

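Let \(f(x, y)=g(x*y)\). Then \(\nabla_xf=\nabla_{x*y}g*\overleftarrow{y}\), where \(\nabla_{x*y}g\) is the gradient of \(g\) with respect to its argument \(x*y\) and \(\overleftarrow{y}\) denotes \(y\) reversed.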

We need a scalar-valued function so that the gradient is defined. For simplicity, let \(g(u, v)=u\odot v\), so that \(f(x, y, z)=(x*y)\odot z\). We should then have \(\nabla_xf=z*\overleftarrow{y}\).

Strictly speaking, the result depends on the convolution mode, but the manuscript uses the same convolution operator \(*\) throughout for simplicity.
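
The identity can be sketched by writing the convolution as a sum. The derivation below assumes that all sequences are zero outside their support and uses the convention \(\overleftarrow{y}_m=y_{-m}\); the index offset this introduces is what the convolution mode fixes in practice.

\[
(x*y)_n=\sum_k x_k\,y_{n-k},\qquad
f(x,y,z)=\sum_n z_n\sum_k x_k\,y_{n-k}
\]

\[
\frac{\partial f}{\partial x_k}=\sum_n z_n\,y_{n-k}=\sum_n z_n\,\overleftarrow{y}_{k-n}=(z*\overleftarrow{y})_k
\]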

convolution mode = ‘full’

f = lambda x, y, z: inner(convolve(x, y, mode='full'), z)
c = array([4.5, 1., 5., -2.2, -11, 3.4, 9])
grad(f)(a, b, c)  # evaluate the gradient of f with respect to its first argument at the given inputs a, b and c
DeviceArray([  4.5     ,  17.84    ,  34.800003, -36.08    , -46.      ],            dtype=float32)
convolve(c, b[::-1], mode='valid')  # we manually evaluate the gradient
DeviceArray([  4.5     ,  17.84    ,  34.800003, -36.08    , -46.      ],            dtype=float32)
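
The general form with an arbitrary scalar function \(g\) can be checked in the same way. The sketch below assumes a hypothetical loss \(g(u)=\sum_n u_n^2\), chosen only for illustration, whose gradient \(2u\) is propagated back through the 'full' convolution by a 'valid' convolution with the reversed kernel:

```python
import jax.numpy as jnp
from jax import grad

a = jnp.array([1.1, 2., -3, 2.5, 7.])
b = jnp.array([3., 2., -2.2])

# Hypothetical scalar loss g(u) = sum(u**2), chosen only for illustration.
g = lambda u: jnp.sum(u ** 2)
f = lambda x, y: g(jnp.convolve(x, y, mode='full'))

# Gradient of f with respect to x from automatic differentiation.
auto = grad(f)(a, b)

# Chain rule by hand: the gradient of g at u = x * y is 2u, and it is
# sent back through the 'full' convolution by a 'valid' convolution
# with the reversed kernel.
u = jnp.convolve(a, b, mode='full')
manual = jnp.convolve(2 * u, b[::-1], mode='valid')
```

Both `auto` and `manual` should agree up to float32 rounding.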

convolution mode = ‘same’

f = lambda x, y, z: inner(convolve(x, y, mode='same'), z)

c = array([4.5, 1., 5., -2.2, -11])
grad(f)(a, b, c)
DeviceArray([  6.8     ,   4.5     ,  17.84    ,  34.800003, -28.6     ],            dtype=float32)
convolve(c, b[::-1], mode='same')
DeviceArray([  6.8     ,   4.5     ,  17.84    ,  34.800003, -28.6     ],            dtype=float32)
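
The match in 'same' mode appears to rely on the kernel having odd length, as `b` does here. For an even-length kernel the centred window of the backward convolution shifts by one sample. The sketch below uses a hypothetical even-length kernel `y_even` and slices the 'full' backward convolution explicitly; the slice offset `k // 2` is an assumption derived from how 'same' crops the 'full' output, not something taken from the manuscript:

```python
import jax.numpy as jnp
from jax import grad

a = jnp.array([1.1, 2., -3, 2.5, 7.])
y_even = jnp.array([3., 2., -2.2, 0.5])  # hypothetical even-length kernel
c = jnp.array([4.5, 1., 5., -2.2, -11])

f = lambda x, y, z: jnp.inner(jnp.convolve(x, y, mode='same'), z)
auto = grad(f)(a, y_even, c)

# Slice the 'full' backward convolution so that it lines up with x.
# For an odd-length kernel this slice is what mode='same' returns.
n, k = a.shape[0], y_even.shape[0]
full = jnp.convolve(c, y_even[::-1], mode='full')
manual = full[k // 2 : k // 2 + n]

# The plain 'same' shortcut is shifted by one sample for this kernel.
shortcut = jnp.convolve(c, y_even[::-1], mode='same')
```

Here `manual` should agree with `auto`, while `shortcut` should not.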

convolution mode = ‘valid’

f = lambda x, y, z: inner(convolve(x, y, mode='valid'), z)

c = array([4.5, 1., 5.])
grad(f)(a, b, c)
DeviceArray([-9.900001,  6.8     ,  4.5     , 13.      , 15.      ], dtype=float32)
convolve(c, b[::-1], mode='full')
DeviceArray([-9.900001,  6.8     ,  4.5     , 13.      , 15.      ], dtype=float32)
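
The three cases can be summarised as a mapping from the forward mode to the mode of the backward convolution. The sketch below is a hypothetical helper, not part of JAX; it assumes `len(x) >= len(y)` and, for 'same', an odd-length kernel. It is checked against `jax.vjp` with an arbitrary cotangent:

```python
import jax.numpy as jnp
from jax import vjp

# Backward convolution mode for each forward mode, as observed above.
# Assumes len(x) >= len(y) and, for 'same', an odd-length kernel y.
BACKWARD_MODE = {'full': 'valid', 'same': 'same', 'valid': 'full'}

def conv_vjp_x(y, cotangent, mode):
    """Hypothetical helper: VJP of convolve(x, y, mode) with respect to x."""
    return jnp.convolve(cotangent, y[::-1], mode=BACKWARD_MODE[mode])

a = jnp.array([1.1, 2., -3, 2.5, 7.])
b = jnp.array([3., 2., -2.2])

results = {}
for mode in BACKWARD_MODE:
    # Forward pass and the function that pulls cotangents back to x.
    out, vjp_fn = vjp(lambda x: jnp.convolve(x, b, mode=mode), a)
    # Arbitrary cotangent with the same shape as the forward output.
    cotangent = jnp.arange(1.0, out.shape[0] + 1.0)
    (auto,) = vjp_fn(cotangent)
    results[mode] = (auto, conv_vjp_x(b, cotangent, mode))
```

Each pair in `results` holds the automatic and the manual vector-Jacobian product for one mode; the two entries should agree up to float32 rounding.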