"""Projektowanie filtru IIR różnymi metodami"""

import numpy as np
import scipy.signal as sig
import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams['figure.figsize'] = (8, 4)

fs = 48000


def get_signal(n=4096):
    """Sygnał testowy - 5 sinusów i szum."""
    t = np.arange(n) / fs
    fr = np.array([500, 1000, 1500, 2000, 2500]).reshape(-1, 1)
    x = np.sum(np.sin(2 * np.pi * t * fr), axis=0)
    x = x + 0.05 * np.random.randn(len(x))
    x = x / np.max(np.abs(x))
    return x


def plot_filter(b, a, x):
    """Wykres charakterystyki częstotliwościowej filtru i widma po filtracji."""
    if a is None:  # SOS
        y = sig.sosfilt(b, x)
        b, a = sig.sos2tf(b)
    else:
        y = sig.lfilter(b, a, x)
    w, hf = sig.freqz(b, a, worN=2048, fs=fs)
    lag = 20
    x_sp = np.fft.rfft(x[:2048] * np.hamming(2048))
    x_spdb = 20 * np.log10(np.abs(x_sp) / 1024)
    y_sp = np.fft.rfft(y[lag:lag + 2048] * np.hamming(2048))
    y_spdb = 20 * np.log10(np.abs(y_sp) / 1024)
    f = np.fft.rfftfreq(2048, 1 / fs)
    fig, ax = plt.subplots(2, sharex=True, tight_layout=True, figsize=(8, 5))
    ax[0].plot(w, 20 * np.log10(np.abs(hf)))
    ax[0].set_ylabel('Wzmocnienie [dB]')
    ax[0].set_title('Charakterystyka częstotliwościowa filtru')
    ax[0].grid()
    # ax[0].set_ylim(bottom=-90)
    ax[1].plot(f, x_spdb, c='#a0a0a0', label='Oryginalny')
    ax[1].plot(f, y_spdb, label='Po filtracji')
    ax[1].set_ylabel('Poziom [dB]')
    ax[1].set_title('Widmo sygnału')
    ax[1].grid()
    ax[1].set_ylim(bottom=-90)
    ax[1].set_xlabel('Częstotliwość [Hz]')
    return ax


x = get_signal()

# Projekt 1 - dwa ostatnie prążki
# filtr DP 1700, eliptyczny, N=8, Rp=1, Rz=60, forma TF
b1, a1 = sig.iirfilter(8, 1700, btype='lowpass', ftype='ellip', rp=1, rs=60, fs=fs)
ax1 = plot_filter(b1, a1, x)
ax1[0].set_ylim(bottom=-80)
ax1[1].set_xlim(0, 4000)

# Projekt 2 - usuwamy pierwszy i ostatni prążek
# filtr PP 750-2250 Hz, eliptyczny, Rp=1, Rz=60, forma SOS
sos2 = sig.iirdesign(wp=(750, 2250), ws=(500, 2500), gpass=1, gstop=60,
        ftype='ellip', output='sos', fs=fs)
ax2 = plot_filter(sos2, None, x)
ax2[0].set_ylim(bottom=-80)
ax2[1].set_xlim(0, 4000)

print(sos2.shape)

np.set_printoptions(suppress=True, precision=4)
print(sos2)

plt.show()
