BP of Conv1D
%pip install jax==0.2.13 jaxlib==0.1.66
from jax import grad
from jax.numpy import array, convolve, inner
a = array([1.1, 2., -3, 2.5, 7.])
b = array([3., 2., -2.2])
Inner product
A frequently used operation is the inner product \(\odot\). Let \(f(x, y)=x\odot y\); then \(\nabla_xf=y\).
c = array([4.5, 1., 5.])
grad(inner)(b, c) # returns the value of c
DeviceArray([4.5, 1. , 5. ], dtype=float32)
Conv1D
Another important case is convolution, whose vector-Jacobian product is used heavily in the backprop analysis (GDBP).
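Let \(f(x, y)=g(x*y)\). Then \(\nabla_xf=\nabla_{x*y}\,g*\overleftarrow{y}\), where \(\nabla_{x*y}\,g\) is the gradient of \(g\) with respect to its input, evaluated at \(x*y\), and \(\overleftarrow{y}\) denotes \(y\) reversed.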
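The general rule does not require \(g\) to be linear. A minimal numerical check, assuming a hypothetical nonlinear scalar function `g` chosen only for illustration and the 'full' convolution mode (whose gradient is a 'valid' convolution, as in the first case below):

```python
import jax.numpy as jnp
from jax import grad

# Inputs reused from this notebook.
a = jnp.array([1.1, 2., -3, 2.5, 7.])
b = jnp.array([3., 2., -2.2])

def g(u):
    # Hypothetical nonlinear scalar-valued function of u = x * y.
    return jnp.sum(jnp.tanh(0.1 * u) ** 2)

def f(x, y):
    return g(jnp.convolve(x, y, mode='full'))

u = jnp.convolve(a, b, mode='full')
grad_g = grad(g)(u)                                  # gradient of g evaluated at u = a * b
lhs = grad(f)(a, b)                                  # gradient of f w.r.t. x from autodiff
rhs = jnp.convolve(grad_g, b[::-1], mode='valid')    # grad_g convolved with reversed b
print(jnp.allclose(lhs, rhs, atol=1e-5))
```

Using `grad(g)` for \(\nabla_{x*y}\,g\) keeps the check independent of any hand-derived derivative of `g`; the comparison should print `True`.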
We need \(g\) to be scalar-valued so that its gradient exists. For simplicity, let \(g(u, v)=u\odot v\), i.e. \(f(x, y, z)=(x*y)\odot z\). Since \(\nabla_u g=z\) by the inner-product rule above, we should have \(\nabla_xf=z*\overleftarrow{y}\).
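For the 'full' mode, the formula follows from writing the convolution out index by index (0-based indices, \(x\in\mathbb{R}^N\), \(y\in\mathbb{R}^K\), \(z\in\mathbb{R}^{N+K-1}\), and \(y_j=0\) outside \(0\le j\le K-1\)):

\[
\begin{aligned}
f &= \sum_{n=0}^{N+K-2} z_n \sum_{m=0}^{N-1} x_m\, y_{n-m},\\
\frac{\partial f}{\partial x_m} &= \sum_{n} z_n\, y_{n-m} = \sum_{k=0}^{K-1} z_{m+k}\, y_k = \sum_{k=0}^{K-1} z_{m+k}\, \overleftarrow{y}_{K-1-k}, \qquad \overleftarrow{y}_j = y_{K-1-j}.
\end{aligned}
\]

The last sum is entry \(m\) (for \(m=0,\dots,N-1\)) of the 'valid' convolution of \(z\) with \(\overleftarrow{y}\), which is why the 'full' case below is checked against `convolve(c, b[::-1], mode='valid')`.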
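Strictly speaking, the result depends on the convolution mode: in the three cases below, a forward convolution in mode 'full', 'same' or 'valid' has a gradient given by a convolution in mode 'valid', 'same' or 'full', respectively. For simplicity, the manuscript still uses the same operator \(*\) for all of them.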
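The term "vector-Jacobian product" can be made concrete by materialising the Jacobian of \(x\mapsto x*y\) with `jax.jacfwd` and multiplying it by a cotangent vector \(z\). The sketch below compares \(z^\top J\) with the mode pairing of the three cases that follow; the cotangent `z = 1, 2, ...` is an arbitrary choice, and the pairing is only claimed for the odd-length kernel `y` used in this notebook (the centring of 'same' is asymmetric for even-length kernels, so that pairing is not assumed to carry over).

```python
import jax.numpy as jnp
from jax import jacfwd

# Inputs reused from this notebook; y has odd length (3).
x = jnp.array([1.1, 2., -3, 2.5, 7.])
y = jnp.array([3., 2., -2.2])

# Forward mode -> mode of the convolution that realises the VJP,
# as observed in the three cases of this notebook (odd-length y).
backward_mode = {'full': 'valid', 'same': 'same', 'valid': 'full'}

results = {}
for mode, bwd in backward_mode.items():
    def conv(x_, mode=mode):
        return jnp.convolve(x_, y, mode=mode)

    J = jacfwd(conv)(x)                    # Jacobian of x -> x*y, shape (len(out), len(x))
    z = jnp.arange(1., J.shape[0] + 1.)    # arbitrary cotangent vector 1, 2, ..., len(out)
    vjp_matrix = z @ J                     # explicit vector-Jacobian product z^T J
    vjp_conv = jnp.convolve(z, y[::-1], mode=bwd)
    results[mode] = bool(jnp.allclose(vjp_matrix, vjp_conv, atol=1e-4))
    print(mode, J.shape, results[mode])
```

Each row of `J` is a shifted copy of `y`, so \(J\) is a banded Toeplitz matrix; multiplying by \(z\) from the left slides \(z\) across the reversed kernel, which is exactly the convolution \(z*\overleftarrow{y}\). The loop should report `True` for all three modes.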
convolution mode = 'full'
f = lambda x, y, z: inner(convolve(x, y, mode='full'), z)
c = array([4.5, 1., 5., -2.2, -11, 3.4, 9])
grad(f)(a, b, c) # evaluate the gradient of f with respect to its first argument at the inputs a, b and c
DeviceArray([ 4.5 , 17.84 , 34.800003, -36.08 , -46. ], dtype=float32)
convolve(c, b[::-1], mode='valid') # we manually evaluate the gradient
DeviceArray([ 4.5 , 17.84 , 34.800003, -36.08 , -46. ], dtype=float32)
convolution mode = 'same'
f = lambda x, y, z: inner(convolve(x, y, mode='same'), z)
c = array([4.5, 1., 5., -2.2, -11])
grad(f)(a, b, c)
DeviceArray([ 6.8 , 4.5 , 17.84 , 34.800003, -28.6 ], dtype=float32)
convolve(c, b[::-1], mode='same')
DeviceArray([ 6.8 , 4.5 , 17.84 , 34.800003, -28.6 ], dtype=float32)
convolution mode = 'valid'
f = lambda x, y, z: inner(convolve(x, y, mode='valid'), z)
c = array([4.5, 1., 5.])
grad(f)(a, b, c)
DeviceArray([-9.900001, 6.8 , 4.5 , 13. , 15. ], dtype=float32)
convolve(c, b[::-1], mode='full')
DeviceArray([-9.900001, 6.8 , 4.5 , 13. , 15. ], dtype=float32)