Performance penalty of using fewer taps

import os
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from functools import partial
import matplotlib.pyplot as plt
from commplax import util
from gdbp import gdbp_base as gb, data as gdat, plot as gplt, aux
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# training and test datasets: only the first element returned by gdat.load is kept
ds_train = gdat.load(1, 0, 4, 2)[0]
ds_test = gdat.load(2, 0, 4, 1)[0]
aux.dl_trained_params()  # fetch the pre-trained parameters used when use_pretrained_params=True
loc_trained_params = './trained_params/pretrained'  # root directory of the downloaded parameters
100% [..........................................................................] 1582928 / 1582928
# Results table: starts empty, with one explicitly typed column per recorded field.
_RESULT_COLUMNS = (('ntaps', 'int'),
                   ('dtaps', 'int'),
                   ('xi', 'float'),
                   ('Model', 'str'),
                   ('Q', 'float'))
df_test_res = pd.DataFrame({column: pd.Series(dtype=dtype) for column, dtype in _RESULT_COLUMNS})
def init_models(data, hparams, mode='train'):
  '''Build the FDBP, EDBP and GDBP models for one hyper-parameter set.

  Args:
    data: dataset handed to gb.model_init; its `.a` attribute feeds gb.fdbp_init.
    hparams: dict with keys 'dtaps', 'ntaps' and 'xi'.
    mode: model mode forwarded in the shared config ('train' by default).

  Returns:
    Tuple (fdbp, edbp, gdbp) of the three initialised models, in that order.
  '''
  # keyword arguments are evaluated left to right, i.e. fdbp_init runs last, as before
  shared_conf = dict(mode=mode,
                     dtaps=hparams['dtaps'],
                     ntaps=hparams['ntaps'],
                     init_fn=gb.fdbp_init(data.a, steps=3, xi=hparams['xi']))
  build = partial(gb.model_init, data, shared_conf)
  # first argument of each build call: parameter-name patterns passed to gb.model_init
  # (presumably the parameter groups kept fixed during training — confirm in gdbp_base)
  variants = (('FDBP', [('fdbp_0',)]),
              ('EDBP', [('fdbp_0', r'DConv_\d')]),
              ('GDBP', []))
  return tuple(build(patterns, name=label) for label, patterns in variants)


def sweep_hparams(hparams, df_test_res=df_test_res, use_pretrained_params=False, save_params=False):
  '''Train (or load) and test FDBP/EDBP/GDBP for every hyper-parameter set.

  Args:
    hparams: sequence of dicts with keys 'ntaps', 'dtaps' and 'xi'.
    df_test_res: DataFrame the new result rows are added to (it is not modified in place).
    use_pretrained_params: load each model's parameters from `loc_trained_params`
      instead of training it.
    save_params: after training, save the selected parameters under `loc_trained_params`.

  Returns:
    A new DataFrame holding `df_test_res` followed by one row per (hparam, model),
    with columns 'dtaps', 'ntaps', 'xi', 'Model' and 'Q' (test Q-factor).
  '''
  # Collect rows and concatenate once at the end: DataFrame.append was removed in
  # pandas 2.0, and appending row by row copies the whole frame each time.
  rows = []

  for i, hp in enumerate(tqdm(hparams, desc='sweeping hparams')):  # iterate hparams
    models_train = init_models(ds_train, hp)
    models_test = init_models(ds_test, hp, mode='test')

    for j, (model_train, model_test) in enumerate(tqdm(zip(models_train, models_test),
                                                       total=len(models_train),
                                                       desc='iterating models',
                                                       leave=False)):  # iterate models
      params_file = os.path.join(loc_trained_params,
                                 'snr_vs_dtaps_ntaps',
                                 'params_%d_%d' % (i, j))  # params_{hparam}_{model}

      if use_pretrained_params:
        params = util.load_variable(params_file)
      else:
        # 3-slot delay line: when training ends, `params` holds the parameters yielded
        # 3 iterations before the last one (None if training yielded 3 times or fewer).
        # Reset per model so a run that yields nothing never reuses another model's params.
        params = None
        params_queue = [None] * 3

        for _, p, _ in gb.train(model_train, ds_train, n_iter=2000):
          params_queue.append(p)
          params = params_queue.pop(0)

        if save_params:
          util.save_variable(params, params_file)

      test_Q = gb.test(model_test, params, ds_test)[0].QSq.total
      # collect result
      rows.append({'dtaps': hp['dtaps'],
                   'ntaps': hp['ntaps'],
                   'xi': hp['xi'],
                   'Model': model_test.name,
                   'Q': test_Q})

    if i % 10 == 0:
      util.clear_xla_cache()  # clear compile cache periodically to save RAM

  if not rows:
    return df_test_res
  # concat keeps df_test_res's column order and adds any new columns after it
  return pd.concat([df_test_res, pd.DataFrame(rows)], ignore_index=True)
# make sweeping grid

# caveat: tap counts must be odd numbers
NTAPS = np.arange(1, 41, 6)      # N-filter lengths: 1, 7, ..., 37
DTAPS = np.arange(181, 261, 12)  # D-filter lengths: 181, 193, ..., 253
# optimal xi for each N-filter length (first CSV column; assumed row-aligned with NTAPS)
OPTXI = pd.read_csv('../../source_data/fdbp_opt_xi_ntaps.csv').values[:, 0]

# zip() below would silently drop ntaps values that have no matching xi
if len(OPTXI) < len(NTAPS):
  raise ValueError('expected at least %d optimal xi values, got %d' % (len(NTAPS), len(OPTXI)))

# one entry per (ntaps, dtaps) pair; xi depends on ntaps only
hparams = [{'ntaps': ntaps, 'dtaps': dtaps, 'xi': xi}
           for ntaps, xi in zip(NTAPS, OPTXI)
           for dtaps in DTAPS]
# Run the full sweep with the downloaded pre-trained parameters (no training, nothing saved).
# It may take a while to finish.
df_test_res = sweep_hparams(hparams,
                            df_test_res=df_test_res,
                            use_pretrained_params=True,
                            save_params=False)
df_test_res  # display the results table
ntaps dtaps xi Model Q
0 1 181 0.15 FDBP 7.239062
1 1 181 0.15 EDBP 7.244903
2 1 181 0.15 GDBP 8.223946
3 1 193 0.15 FDBP 7.611000
4 1 193 0.15 EDBP 7.618716
... ... ... ... ... ...
142 37 241 1.10 EDBP 8.596812
143 37 241 1.10 GDBP 8.741637
144 37 253 1.10 FDBP 8.590569
145 37 253 1.10 EDBP 8.670206
146 37 253 1.10 GDBP 8.746362

147 rows × 5 columns

# 3-D wireframes of test Q-factor over (N-filter length, D-filter length), one per model.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, figsize=(5, 5), dpi=200)
ax.view_init(19, -30)
curves = 20  # NOTE(review): not used below
cms = [plt.cm.Greys, plt.cm.Blues, plt.cm.Reds]  # one colormap per model

colors = plt.cm.RdPu_r(np.linspace(0., 0.5, 3))[::-1]  # NOTE(review): not used below

res_by_model = df_test_res.groupby('Model')  # group once instead of on every iteration
for cm, mod in zip(cms, ['FDBP', 'EDBP', 'GDBP']):
  # Q grid with rows = dtaps and columns = ntaps. DataFrame.pivot arguments are
  # keyword-only since pandas 2.0; positional calls raise TypeError there.
  df = res_by_model.get_group(mod).pivot(index='dtaps', columns='ntaps', values='Q')
  X, Y = np.meshgrid(df.columns.values, df.index.values)
  Z = df.values
  gplt.wireframe_cmap(ax, X, Y, Z, cmap=cm, offset=[-1.2, 0], label=mod)

ax.set_xlabel('N-filter length')
ax.set_ylabel('D-filter length')
ax.set_zlabel('Q-factor (dB)')
ax.invert_xaxis()
ax.legend(loc=(-0.18, 0.4))
<matplotlib.legend.Legend at 0x7fde1cb677f0>
../_images/benchmark_few_taps_9_1.png