Training loss
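This section compares the training dynamics of the tuned model ('NN + MIMO-DDLMS', all parameters trainable) against the FDBP baseline ('MIMO-DDLMS only', DBP filters frozen): first how far each parameter group sits from its final value at every iteration, then the two training-loss curves.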

import jax
import numpy as np
import matplotlib.pyplot as plt
from commplax import equalizer as eq, xcomm, plot as cplt
from gdbp import gdbp_base as gb, data as gdat, plot as gplt
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
data = gdat.load(1, 0, 4, 2)[0]  # take the first record returned by the loader
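Optionally, a quick look at the attribute record consumed by gb.fdbp_init below (the only field of data assumed here) can help verify that the intended dataset was loaded:

print(data.a)  # link/signal attributes that parameterize the DBP initialization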
def train_gdbp(data, dtaps=261, ntaps=41):
  # GDBP: no static parameter keys, so every filter is trainable
  gdbp = gb.model_init(data, {'ntaps': ntaps, 'dtaps': dtaps, 'init_fn': gb.fdbp_init(data.a, xi=1.1, steps=3)}, [])
  # gb.train yields a (loss, params, state) tuple per iteration; zip(*...)
  # transposes that into (losses, params history, state history)
  return list(zip(*list(gb.train(gdbp, data, n_iter=2000))))


def train_fdbp(data, dtaps=261, ntaps=41):
  # same model and initializer, but with the fdbp_0 subtree held static
  fdbp = gb.model_init(data, {'ntaps': ntaps, 'dtaps': dtaps, 'init_fn': gb.fdbp_init(data.a, xi=1.1, steps=3)}, [('fdbp_0',)])
  return list(zip(*list(gb.train(fdbp, data, n_iter=2000))))
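The only difference between the two setups is the last argument of gb.model_init, which lists the parameter subtrees to hold static (untrained): an empty list leaves every parameter trainable (the GDBP case), while [('fdbp_0',)] freezes the fdbp_0 filters so that only the adaptive back end keeps updating (the FDBP baseline).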


def tree_diff(trees, **kwargs):
  # for each snapshot i, the leafwise normalized squared deviation from the
  # final snapshot: mean(|theta_i - theta_N|^2) / mean(|theta_N|^2)
  diff = [jax.tree_map(lambda a, b: np.mean(np.abs(a - b)**2)/np.mean(np.abs(b)**2),
                       trees[i], trees[-1], **kwargs) for i in range(len(trees)-1)]
  # restack: one list per leaf instead of one tree per iteration
  return jax.tree_map(lambda *xs: list(xs), *diff)
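As a toy illustration (made-up numbers, not experiment data), tree_diff turns a list of parameter snapshots into one convergence trajectory per leaf:

snaps = [{'w': np.array([1.0, 2.0])},
         {'w': np.array([1.5, 2.5])},
         {'w': np.array([2.0, 3.0])}]
tree_diff(snaps)  # {'w': [~0.154, ~0.038]}: each snapshot's normalized distance to the last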
loss_gdbp, params, state = train_gdbp(data)  # loss curve plus per-iteration snapshots
loss_fdbp = train_fdbp(data)[0]  # only the loss curve is needed for comparison
params_diff = tree_diff(params)
state_diff = tree_diff(state)
plt.figure(figsize=(6, 4), dpi=300)
sli = slice(0, 2000)
ax1 = plt.gca()
ax1.plot(params_diff['fdbp_0']['DConv_0']['kernel'][sli], label=r'D-filter$_1$')
ax1.plot(params_diff['RConv']['kernel'][sli], label=r'R-filter')
ax1.plot(params_diff['fdbp_0']['NConv_0']['kernel'][sli], '--', label=r'N-filter$_1$', markersize=1)
ax1.legend(fontsize=8, loc='lower left')
ax1.set_ylabel(r'mean$(|\frac{\mathbf{\theta}_i - \mathbf{\theta}_{2000}}{\mathbf{\theta}_{2000}}|^2)$')
ax1.set_xlabel('iteration i')

axins = ax1.inset_axes([0.4, 0.4, 0.3, 0.3])
axins.plot(params_diff['fdbp_0']['DConv_0']['kernel'], label=r'D-filter$_1$')
axins.plot(params_diff['RConv']['kernel'], label=r'R-filter')
axins.set_xlim(0, 1000)
# axins.set_ylim(2.92, 3.4)
axins.set_xticks([0, 500, 1000])
ax1.indicate_inset_zoom(axins, edgecolor="slateblue")
[figure ../_images/training_loss_7_1.png: normalized deviation of the D-, N- and R-filter coefficients from their final values, with inset zoom over the first 1000 iterations]
def plot_loss(loss, ax=None, label=None, alpha=0.4):
    if ax is None:
        plt.figure()
        ax = plt.gca()
    # 20-point moving average; mode='same' zero-pads the edges, so the
    # distorted tail is dropped before plotting
    loss_mean = np.convolve(loss, np.ones(20) / 20, mode='same')
    p = ax.plot(loss, alpha=alpha, label=label)  # raw per-iteration loss, faded
    ax.plot(loss_mean[:-50], color=p[0].get_color())  # smoothed loss, same color
    ax.set_xlabel('iteration')
    ax.set_ylabel('MSE')
    ax.legend()
plt.figure(figsize=(6, 4), dpi=300)
ax = plt.gca()
plot_loss(loss_gdbp, ax=ax, label='NN + MIMO-DDLMS')
plot_loss(loss_fdbp, ax=ax, label='MIMO-DDLMS only')
ax.set_xlabel('iteration i')
ax.set_ylim([0.023, 0.045])
[figure ../_images/training_loss_9_1.png: training loss (MSE) over iterations for both configurations]