Skip to content

Commit

Permalink
Fixed presentation and saving data
Browse files Browse the repository at this point in the history
  • Loading branch information
lostinafro committed Mar 23, 2023
1 parent b6c79ba commit e8b1d91
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions cb_bsw_vs_minimum_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import matplotlib.pyplot as plt

import scenario.common as cmn
from environment import RisProtocolEnv, command_parser, NOISE_POWER_dBm, TX_POW_dBm, T, OUTPUT_DIR, TAU
from environment import RisProtocolEnv, command_parser, NOISE_POWER_dBm, TX_POW_dBm, T, OUTPUT_DIR, TAU, BANDWIDTH, NUM_PILOTS
import os

from matplotlib import rc

# rc('font', **{'family': 'sans serif', 'serif': ['Computer Modern']})
# rc('text', usetex=True)

##################################################
# Parameters
Expand All @@ -24,9 +22,6 @@
# Get noise power
noise_power = cmn.dbm2watt(NOISE_POWER_dBm)

# Define number of pilots
num_pilots = 1

# For grid mesh
num_users = int(1e4)

Expand Down Expand Up @@ -87,7 +82,7 @@
h_eq_cb_fixed = (g_rb.conj()[np.newaxis, np.newaxis, :] * codebook_fixed[np.newaxis, :, :] * h_ur[:, np.newaxis, :]).sum(axis=-1)

# Generate some noise
var = noise_power / num_pilots
var = noise_power / NUM_PILOTS
bsw_noise_ = np.sqrt(var) * noise_

# Compute the SNR of each user when using CB scheme
Expand Down Expand Up @@ -123,7 +118,7 @@
index = np.argmax(mask)

# Store results
se_flex[ms, uu] = np.log2(1 + minimum_snr)
se_flex[ms, uu] = np.log2(1 + minimum_snr) if snr_cb_flexi[uu, index] > minimum_snr else 0.
n_configurations_flex[ms, uu] = index + 1

# Go through all users (FIXED)
Expand All @@ -135,12 +130,15 @@
if np.sum(mask) == 0:
continue

# Get the index of the highest snr satisfying the constraint
index = np.argmax(snr_cb_hat_fixed[uu] * mask)

# Store results
se_fixed[ms, uu] = np.log2(1 + minimum_snr)
se_fixed[ms, uu] = np.log2(1 + minimum_snr) if snr_cb_fixed[uu, index] > minimum_snr else 0.


## Varying the frame duration
tau_plot = [8.75, 10., 20.]
tau_plot = [7.5, 15, 30]

opt_kpi_flexi = np.zeros((len(TAU), len(setups)))
opt_kpi_fixed = np.zeros((len(TAU), len(setups)))
Expand All @@ -156,7 +154,7 @@
datafilename = 'data/opt_ce_vs_tau_' + setup + '.npz'

# Load data
rate = np.squeeze(np.load(datafilename)['rate'][TAU == total_tau])
rate = np.squeeze(np.load(datafilename)['rate'][TAU == total_tau]) * BANDWIDTH / 1e6

if total_tau in tau_plot:
axes[ss].plot(minimum_snr_range_dB, np.mean(rate) * np.ones_like(minimum_snr_range), linewidth=1.5, color='black', label='OPT-CE')
Expand All @@ -170,6 +168,7 @@
else:
tau_setup = 3 * T

# Flexible structure
# Pre-log term
tau_alg = (2 * n_configurations_flex - 1) * T
prelog_term = 1 - (tau_setup + tau_setup + tau_alg) / total_tau
Expand All @@ -178,18 +177,19 @@
rate_cb_bsw = prelog_term * se_flex
avg_rate_cb_bsw_flexi = np.mean(rate_cb_bsw, axis=-1)
if total_tau in tau_plot:
axes[ss].plot(minimum_snr_range_dB, avg_rate_cb_bsw_flexi, linewidth=1.5, linestyle=':', label='CB-BSW: Flexible', color='red')
axes[ss].plot(minimum_snr_range_dB, avg_rate_cb_bsw_flexi * BANDWIDTH / 1e6, linewidth=1.5, linestyle=':', label='CB-BSW: Flexible', color='red')

opt_kpi_flexi[tt, ss] = minimum_snr_range_dB[np.argmax(avg_rate_cb_bsw_flexi)]

# Fixed structure
# Pre-log term
tau_alg = num_configs * T
prelog_term = 1 - (tau_setup + tau_setup + tau_alg) / total_tau
prelog_term = np.maximum(1 - (tau_setup + tau_setup + tau_alg) / total_tau, 0)

rate_cb_bsw = prelog_term * se_fixed
avg_rate_cb_bsw_fixed = np.mean(rate_cb_bsw, axis=-1)
if total_tau in tau_plot:
axes[ss].plot(minimum_snr_range_dB, avg_rate_cb_bsw_fixed, linewidth=1.5, linestyle='--', label='CB-BSW: Fixed', color='blue')
axes[ss].plot(minimum_snr_range_dB, avg_rate_cb_bsw_fixed * BANDWIDTH / 1e6, linewidth=1.5, linestyle='--', label='CB-BSW: Fixed', color='blue')

opt_kpi_fixed[tt, ss] = minimum_snr_range_dB[np.argmax(avg_rate_cb_bsw_fixed)]

Expand All @@ -206,7 +206,7 @@
# Plot for specific values
if total_tau in tau_plot:
cmn.printplot(fig, axes, render=render, filename=f'rate_vs_gamma0_{total_tau:.2f}', dirname=OUTPUT_DIR,
labels=[r'$\gamma_0$ [dB]', r'$R$ [bit/Hz/s]', r'$R$ [bit/Hz/s]'], orientation='vertical')
labels=[r'$\gamma_0$ [dB]', r'$R$ [Mbit/s]', r'$R$ [Mbit/s]'], orientation='vertical')

# Save optimal threshold values
opt_kpi_fixed = np.mean(opt_kpi_fixed, axis=-1)
Expand Down

0 comments on commit e8b1d91

Please sign in to comment.