
Commit

added explicit time axis

abhierra2 committed Aug 7, 2024
1 parent e226f0b commit 2a1a8b0
Showing 1 changed file with 118 additions and 70 deletions.
188 changes: 118 additions & 70 deletions wave_plot_app.py
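The change replaces the hardcoded 0–10 ms axis with one derived from a configurable recording window; a minimal sketch of that calculation, using names from the diff (the zero-filled array is a placeholder for the interpolated waveform, and the 244-sample base length comes from the original code):

import numpy as np

time_scale = 10.0                       # recording window in ms (sidebar input)
target = int(244 * (time_scale / 10))   # resample target scales with the window

y_values = np.zeros(target)             # stands in for interpolate_and_smooth(final, target)
sampling_rate = len(y_values) / time_scale                               # samples per ms
x_values = np.linspace(0, len(y_values) / sampling_rate, len(y_values))  # 0..time_scale ms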
@@ -68,21 +68,30 @@ def plot_wave(fig, x_values, y_values, color, name, marker_color=None):
if marker_color:
fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='markers', marker=dict(color=marker_color), name=name, showlegend=False))


def calculate_and_plot_wave(df, freq, db, color, threshold=None):
khz = df[(df['Freq(Hz)'] == freq) & (df[db_column] == db)]
if not khz.empty:
index = khz.index.values[0]
final = df.loc[index, '0':].dropna()
final = pd.to_numeric(final, errors='coerce').dropna()
y_values = interpolate_and_smooth(final)

target = 244 * (time_scale / 10)

y_values = interpolate_and_smooth(final, target) # Original y-values for plotting
sampling_rate = len(y_values) / time_scale # Assuming 10 ms duration for 244 points

x_values = np.linspace(0, len(y_values) / sampling_rate, len(y_values))

y_values_for_peak_finding = interpolate_and_smooth(final[:244])
if units == 'Nanovolts':
y_values /= 1000

y_values *= multiply_y_factor
y_values_for_peak_finding *= multiply_y_factor

highest_peaks, relevant_troughs = peak_finding(y_values)
highest_peaks, relevant_troughs = peak_finding(y_values_for_peak_finding)

return y_values, highest_peaks, relevant_troughs
return x_values, y_values, highest_peaks, relevant_troughs
return None, None, None

def plot_waves_single_frequency(df, freq, y_min, y_max, plot_time_warped=False):
@@ -108,23 +117,25 @@ def plot_waves_single_frequency(df, freq, y_min, y_max, plot_time_warped=False):
st.write("Threshold can't be calculated.")

for i, db in enumerate(sorted(db_levels)):
y_values, highest_peaks, relevant_troughs = calculate_and_plot_wave(file_df, freq, db, glasbey_colors[i])
x_values, y_values, highest_peaks, relevant_troughs = calculate_and_plot_wave(file_df, freq, db, glasbey_colors[i])

if y_values is not None:
fig.add_trace(go.Scatter(x=np.linspace(0, 10, len(y_values)), y=y_values, mode='lines', name=f'{int(db)} dB', line=dict(color=glasbey_colors[i])))
if return_units == 'Nanovolts':
y_values *= 1000
fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='lines', name=f'{int(db)} dB', line=dict(color=glasbey_colors[i])))
# Mark the highest peaks with red markers
fig.add_trace(go.Scatter(x=np.linspace(0,10,len(y_values))[highest_peaks], y=y_values[highest_peaks], mode='markers', marker=dict(color='red'), name='Peaks'))#, showlegend=False))
fig.add_trace(go.Scatter(x=x_values[highest_peaks], y=y_values[highest_peaks], mode='markers', marker=dict(color='red'), name='Peaks'))#, showlegend=False))

# Mark the relevant troughs with blue markers
fig.add_trace(go.Scatter(x=np.linspace(0,10,len(y_values))[relevant_troughs], y=y_values[relevant_troughs], mode='markers', marker=dict(color='blue'), name='Troughs'))#, showlegend=False))
fig.add_trace(go.Scatter(x=x_values[relevant_troughs], y=y_values[relevant_troughs], mode='markers', marker=dict(color='blue'), name='Troughs'))#, showlegend=False))

if plot_time_warped:
original_waves.append(y_values.tolist())

if plot_time_warped:
original_waves_array = np.array([wave[:-1] for wave in original_waves])
try:
time = np.linspace(0, 10, original_waves_array.shape[1])
time = np.linspace(0, time_scale, original_waves_array.shape[1])
obj = fs.fdawarp(original_waves_array.T, time)
obj.srsf_align(parallel=True)
warped_waves_array = obj.fn.T
@@ -134,11 +145,18 @@ def plot_waves_single_frequency(df, freq, y_min, y_max, plot_time_warped=False):
pass

if threshold is not None:
y_values, _, _ = calculate_and_plot_wave(file_df, freq, threshold, 'black')
x_values, y_values, _, _ = calculate_and_plot_wave(file_df, freq, threshold, 'black')
if y_values is not None:
fig.add_trace(go.Scatter(x=np.linspace(0, 10, len(y_values)), y=y_values, mode='lines', name=f'Threshold: {int(threshold)} dB', line=dict(color='black', width=5)))
if return_units == 'Nanovolts':
y_values *= 1000
fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='lines', name=f'Threshold: {int(threshold)} dB', line=dict(color='black', width=5)))

if return_units == 'Nanovolts':
y_units = 'Voltage (nV)'
else:
y_units = 'Voltage (μV)'

fig.update_layout(title=f'{selected_files[idx].split("/")[-1]} - Frequency: {freq} Hz', xaxis_title='Time (ms)', yaxis_title='Voltage (μV)')
fig.update_layout(title=f'{selected_files[idx].split("/")[-1]} - Frequency: {freq} Hz', xaxis_title='Time (ms)', yaxis_title=y_units)
fig.update_layout(annotations=annotations)
fig.update_layout(yaxis_range=[y_min, y_max])
fig.update_layout(width=700, height=450)
@@ -150,18 +168,25 @@ def plot_waves_single_tuple(freq, db, y_min, y_max):
db_column = 'Level(dB)' if level else 'PostAtten(dB)'

for idx, file_df in enumerate(selected_dfs):
y_values, highest_peaks, relevant_troughs = calculate_and_plot_wave(file_df, freq, db, 'blue')
x_values, y_values, highest_peaks, relevant_troughs = calculate_and_plot_wave(file_df, freq, db, 'blue')

if y_values is not None:
fig.add_trace(go.Scatter(x=np.linspace(0, 10, len(y_values)), y=y_values, mode='lines', name=f'{selected_files[idx].split("/")[-1]}', showlegend=True))
if return_units == 'Nanovolts':
y_values *= 1000
fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='lines', name=f'{selected_files[idx].split("/")[-1]}', showlegend=True))
# Mark the highest peaks with red markers
fig.add_trace(go.Scatter(x=np.linspace(0,10,len(y_values))[highest_peaks], y=y_values[highest_peaks], mode='markers', marker=dict(color='red'), name='Peaks'))#, showlegend=False))
fig.add_trace(go.Scatter(x=x_values[highest_peaks], y=y_values[highest_peaks], mode='markers', marker=dict(color='red'), name='Peaks'))#, showlegend=False))

# Mark the relevant troughs with blue markers
fig.add_trace(go.Scatter(x=np.linspace(0,10,len(y_values))[relevant_troughs], y=y_values[relevant_troughs], mode='markers', marker=dict(color='blue'), name='Troughs'))#, showlegend=False))
fig.add_trace(go.Scatter(x=x_values[relevant_troughs], y=y_values[relevant_troughs], mode='markers', marker=dict(color='blue'), name='Troughs'))#, showlegend=False))

if return_units == 'Nanovolts':
y_units = 'Voltage (nV)'
else:
y_units = 'Voltage (μV)'

fig.update_layout(width=700, height=450)
fig.update_layout(xaxis_title='Time (ms)', yaxis_title='Voltage (μV)')
fig.update_layout(xaxis_title='Time (ms)', yaxis_title=y_units)
fig.update_layout(annotations=annotations)
fig.update_layout(yaxis_range=[y_min, y_max])

@@ -186,9 +211,11 @@ def plot_3d_surface(df, freq, y_min, y_max):
threshold = None

for db in db_levels:
y_values, _, _ = calculate_and_plot_wave(file_df, freq, db, 'blue')
x_values, y_values, _, _ = calculate_and_plot_wave(file_df, freq, db, 'blue')

if y_values is not None:
if return_units == 'Nanovolts':
y_values *= 1000
original_waves.append(y_values.tolist())

original_waves_array = np.array([wave[:-1] for wave in original_waves])
@@ -202,9 +229,9 @@ def plot_3d_surface(df, freq, y_min, y_max):
warped_waves_array = np.array([])

for i, (db, warped_waves) in enumerate(zip(db_levels, warped_waves_array)):
fig.add_trace(go.Scatter3d(x=[db] * len(warped_waves), y=np.linspace(0, 10, len(warped_waves)), z=warped_waves, mode='lines', name=f'{int(db)} dB', line=dict(color='blue')))
fig.add_trace(go.Scatter3d(x=[db] * len(warped_waves), y=x_values, z=warped_waves, mode='lines', name=f'{int(db)} dB', line=dict(color='blue')))
if db == threshold:
fig.add_trace(go.Scatter3d(x=[db] * len(warped_waves), y=np.linspace(0, 10, len(warped_waves)), z=warped_waves, mode='lines', name=f'Thresh: {int(db)} dB', line=dict(color='black', width=5)))
fig.add_trace(go.Scatter3d(x=[db] * len(warped_waves), y=x_values, z=warped_waves, mode='lines', name=f'Thresh: {int(db)} dB', line=dict(color='black', width=5)))

for i in range(len(time)):
z_values_at_time = [warped_waves_array[j, i] for j in range(len(db_levels))]
@@ -257,53 +284,61 @@ def display_metrics_table(df, freq, db, baseline_level):
).set_properties(**{'width': '100px'})
return styled_metrics_table

def display_metrics_table_all_db(selected_dfs, freq, db_levels, baseline_level):
def display_metrics_table_all_db(selected_dfs, freqs, db_levels, baseline_level):
if level:
db_column = 'Level(dB)'
else:
db_column = 'PostAtten(dB)'

ru = 'μV'
if return_units == 'Nanovolts':
ru = 'nV'

metrics_data = {'File Name': [], 'Frequency (Hz)': [], 'dB Level': [], 'First Peak Amplitude (μV)': [], 'Latency to First Peak (ms)': [], 'Amplitude Ratio (Peak1/Peak4)': [], 'Estimated Threshold': []}
metrics_data = {'File Name': [], 'Frequency (Hz)': [], 'dB Level': [], f'First Peak Amplitude ({ru})': [], 'Latency to First Peak (ms)': [], 'Amplitude Ratio (Peak1/Peak4)': [], 'Estimated Threshold': []}

for file_df, file_name in zip(selected_dfs, selected_files):
try:
threshold = calculate_hearing_threshold(file_df, freq)
except:
threshold = np.nan
pass
for freq in freqs:
try:
threshold = calculate_hearing_threshold(file_df, freq)
except:
threshold = np.nan
pass

for db in db_levels:
khz = file_df[(file_df['Freq(Hz)'] == freq) & (file_df[db_column] == db)]
if not khz.empty:
index = khz.index.values[0]
final = file_df.loc[index, '0':].dropna()
final = pd.to_numeric(final, errors='coerce').dropna()
y_values = interpolate_and_smooth(final)

if units == 'Nanovolts':
y_values /= 1000

y_values *= multiply_y_factor

highest_peaks, relevant_troughs = peak_finding(y_values)

if highest_peaks.size > 0: # Check if highest_peaks is not empty
first_peak_amplitude = y_values[highest_peaks[0]] - y_values[relevant_troughs[0]]
latency_to_first_peak = highest_peaks[0] * (10 / len(y_values)) # Assuming 10 ms duration for waveform

if len(highest_peaks) >= 4 and len(relevant_troughs) >= 4:
amplitude_ratio = (y_values[highest_peaks[0]] - y_values[relevant_troughs[0]]) / (
y_values[highest_peaks[3]] - y_values[relevant_troughs[3]])
else:
amplitude_ratio = np.nan

metrics_data['File Name'].append(file_name.split("/")[-1])
metrics_data['Frequency (Hz)'].append(freq)
metrics_data['dB Level'].append(db)
metrics_data['First Peak Amplitude (μV)'].append(first_peak_amplitude)
metrics_data['Latency to First Peak (ms)'].append(latency_to_first_peak)
metrics_data['Amplitude Ratio (Peak1/Peak4)'].append(amplitude_ratio)
metrics_data['Estimated Threshold'].append(threshold)
for db in db_levels:
khz = file_df[(file_df['Freq(Hz)'] == freq) & (file_df[db_column] == db)]
if not khz.empty:
index = khz.index.values[0]
final = file_df.loc[index, '0':].dropna()
final = pd.to_numeric(final, errors='coerce').dropna()
y_values = interpolate_and_smooth(final)

if units == 'Nanovolts':
y_values /= 1000

y_values *= multiply_y_factor

highest_peaks, relevant_troughs = peak_finding(y_values)

if return_units == 'Nanovolts':
y_values *= 1000

if highest_peaks.size > 0: # Check if highest_peaks is not empty
first_peak_amplitude = y_values[highest_peaks[0]] - y_values[relevant_troughs[0]]
latency_to_first_peak = highest_peaks[0] * (10 / len(y_values)) # Assuming 10 ms duration for waveform

if len(highest_peaks) >= 4 and len(relevant_troughs) >= 4:
amplitude_ratio = (y_values[highest_peaks[0]] - y_values[relevant_troughs[0]]) / (
y_values[highest_peaks[3]] - y_values[relevant_troughs[3]])
else:
amplitude_ratio = np.nan

metrics_data['File Name'].append(file_name.split("/")[-1])
metrics_data['Frequency (Hz)'].append(freq)
metrics_data['dB Level'].append(db)
metrics_data[f'First Peak Amplitude ({ru})'].append(first_peak_amplitude)
metrics_data['Latency to First Peak (ms)'].append(latency_to_first_peak)
metrics_data['Amplitude Ratio (Peak1/Peak4)'].append(amplitude_ratio)
metrics_data['Estimated Threshold'].append(threshold)

metrics_table = pd.DataFrame(metrics_data)
st.dataframe(metrics_table, hide_index=True, use_container_width=True)
@@ -357,14 +392,14 @@ def plot_waves_stacked(freq):

# Plot the waveform
color_scale = glasbey_colors[i]
fig.add_trace(go.Scatter(x=np.linspace(0, 10, len(y_values)),
fig.add_trace(go.Scatter(x=np.linspace(0, time_scale, len(y_values)),
y=y_values,
mode='lines',
name=f'{int(db)} dB',
line=dict(color=color_scale)))

if db == threshold:
fig.add_trace(go.Scatter(x=np.linspace(0, 10, len(y_values)),
fig.add_trace(go.Scatter(x=np.linspace(0, time_scale, len(y_values)),
y=y_values,
mode='lines',
name=f'Thresh: {int(db)} dB',
@@ -385,7 +420,7 @@ def plot_waves_stacked(freq):
st.write(f"Error processing dB level {db}: {e}")

# Add vertical scale bar
if max_value:
if max_value and y_min >= -5 and y_min <= 1:
scale_bar_length = 2 / max_value
fig.add_trace(go.Scatter(x=[10.2, 10.2],
y=[0, scale_bar_length],
Expand Down Expand Up @@ -653,6 +688,7 @@ def all_thresholds():
st.dataframe(threshold_table, hide_index=True, use_container_width=True)
return threshold_table


def peak_finding(wave):
# Prepare waveform
waveform=interpolate_and_smooth(wave)
@@ -843,20 +879,32 @@ def calculate_unsupervised_threshold(df, freq):
# Get distinct frequency and dB level values across all files
distinct_freqs = sorted(pd.concat([df['Freq(Hz)'] for df in dfs]).unique())
distinct_dbs = sorted(pd.concat([df['Level(dB)'] if level else df['PostAtten(dB)'] for df in dfs]).unique())

time_scale = st.sidebar.number_input("Time Scale for Recording (ms)", value=10.0)

multiply_y_factor = st.sidebar.number_input("Multiply Y Values by Factor", value=1.0)

# Unit dropdown options
units = st.sidebar.selectbox("Select Units Used in Collecting Your Data", options=['Microvolts', 'Nanovolts'], index=0)

# Unit dropdown options
return_units = st.sidebar.selectbox("Select Units You Would Like to Analyze With", options=['Microvolts', 'Nanovolts'], index=0)

# Frequency dropdown options
freq = st.sidebar.selectbox("Select Frequency (Hz)", options=distinct_freqs, index=0)

# dB Level dropdown options
db = st.sidebar.selectbox(f'Select dB {is_level}', options=distinct_dbs, index=0)

y_min = st.sidebar.number_input("Y-axis Minimum", value=-5.0)
y_max = st.sidebar.number_input("Y-axis Maximum", value=5.0)
if return_units == 'Nanovolts':
ymin = -5000.0
ymax = 5000.0
else:
ymin = -5.0
ymax = 5.0

y_min = st.sidebar.number_input("Y-axis Minimum", value=ymin)
y_max = st.sidebar.number_input("Y-axis Maximum", value=ymax)

baseline_level_str = st.sidebar.text_input("Set Baseline Level", "0.0")
baseline_level = float(baseline_level_str)
@@ -871,13 +919,13 @@ def calculate_unsupervised_threshold(df, freq):
fig = plot_waves_single_frequency(df, freq, y_min, y_max, plot_time_warped=True)
else:
fig = plot_waves_single_frequency(df, freq, y_min, y_max, plot_time_warped=False)
display_metrics_table_all_db(selected_dfs, freq, distinct_dbs, baseline_level)
display_metrics_table_all_db(selected_dfs, [freq], distinct_dbs, baseline_level)

if st.sidebar.button("Plot Single Wave (Frequency, dB)"):
fig = plot_waves_single_tuple(freq, db, y_min, y_max)
st.plotly_chart(fig)
fig.write_image("fig1.pdf")
display_metrics_table_all_db(selected_dfs, freq, [db], baseline_level)
display_metrics_table_all_db(selected_dfs, [freq], [db], baseline_level)
# Create an in-memory buffer
buffer = io.BytesIO()

@@ -893,10 +941,7 @@ def calculate_unsupervised_threshold(df, freq):
)

if st.sidebar.button("Plot Stacked Waves at Single Frequency"):
if plot_time_warped:
plot_waves_stacked(freq)
else:
plot_waves_stacked(freq)
plot_waves_stacked(freq)

#if st.sidebar.button("Plot Waves with Cubic Spline"):
# fig = plotting_waves_cubic_spline(df, freq, db)
@@ -909,6 +954,9 @@ def calculate_unsupervised_threshold(df, freq):
if st.sidebar.button("Return All Thresholds"):
all_thresholds()

if st.sidebar.button("Return All Peak Analyses"):
display_metrics_table_all_db(selected_dfs, distinct_freqs, distinct_dbs, baseline_level)

#if st.sidebar.button("Plot Waves with Gaussian Smoothing"):
# fig_gauss = plotting_waves_gauss(dfs, freq, db)
# st.plotly_chart(fig_gauss)
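For reference, a short sketch of the microvolt/nanovolt handling the diff threads through plotting and the metrics table (these helper names are hypothetical; the app performs the same conversions inline):

def to_analysis_units(y, units, multiply_y_factor=1.0):
    # Data collected in nanovolts is normalized to microvolts before peak finding.
    if units == 'Nanovolts':
        y = y / 1000
    return y * multiply_y_factor

def to_display_units(y, return_units):
    # Values are converted back to nanovolts only for plotting and table output.
    return y * 1000 if return_units == 'Nanovolts' else y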
