# Source code for contraction_net.annotation

import random
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button


class TimeSeriesAnnotator:
    """
    Class to annotate contraction intervals in multiple time-series files and save the results.

    This class allows users to interactively annotate the start and end of contractions in
    time-series data by clicking on the plot. The user can press the mouse button at the start
    of a contraction interval and release at the end of the contraction interval. The results
    are saved to a specified output directory.

    Only drags that start and end inside the trace plot are recorded; clicks on the
    'Save & Next' / 'Reset' buttons or outside the plot are ignored. Each recorded interval is
    stored with ``start <= end`` regardless of the drag direction.

    Parameters
    ----------
    file_list : list of str
        List of file paths to the time-series data files to be annotated. Paths ending in
        ``_contr.txt`` (annotation outputs) are skipped.
    output_dir : str
        Directory where the output files with annotated contraction intervals will be saved.
        It is created on first save if it does not exist.
    figsize : tuple, optional
        Figure size in inches.

    Methods
    -------
    load_next_file()
        Load the next time-series file for annotation.
    press_callback(event)
        Record the x-coordinate (time) where the mouse button is pressed.
    release_callback(event)
        Record the x-coordinate (time) where the mouse button is released.
    save_and_load_next(event)
        Save the current annotations and load the next time-series file.
    reset_annotations(event)
        Reset the current annotations for the current time-series file.

    Notes
    -----
    In an interactive Jupyter notebook, add the following before running the
    TimeSeriesAnnotator:

    ```
    matplotlib.use('nbagg')
    %matplotlib notebook
    ```
    """

    # State of a drag in progress: x-position of the press and the marker line drawn for it.
    # Both are None while no drag is pending.
    _pending_start = None
    _pending_marker = None

    def __init__(self, file_list, output_dir, figsize=(11, 3.5)):
        # Previously written annotation files must not be offered for annotation again.
        self.file_list = [f for f in file_list if not f.endswith('_contr.txt')]
        self.output_dir = output_dir
        self.current_file_index = 0
        self.start_contraction = []
        self.end_contraction = []

        self.fig, self.ax = plt.subplots(figsize=figsize, constrained_layout=True)
        self.load_next_file()

        self.fig.canvas.mpl_connect('button_press_event', self.press_callback)
        self.fig.canvas.mpl_connect('button_release_event', self.release_callback)

        save_ax = self.fig.add_axes([0.85, 0.01, 0.1, 0.075])
        self.save_button = Button(save_ax, 'Save & Next', color='lightgoldenrodyellow',
                                  hovercolor='0.975')
        self.save_button.on_clicked(self.save_and_load_next)

        reset_ax = self.fig.add_axes([0.75, 0.01, 0.1, 0.075])
        self.reset_button = Button(reset_ax, 'Reset', color='lightcoral', hovercolor='0.975')
        self.reset_button.on_clicked(self.reset_annotations)

        plt.show()

    def _show_current_trace(self):
        """Redraw the loaded trace without annotations and clear all annotation state."""
        self.start_contraction = []
        self.end_contraction = []
        self._pending_start = None
        self._pending_marker = None  # the marker artist is dropped by ax.clear() below
        self.ax.clear()
        self.ax.plot(self.data, c='k')
        current_path = self.file_list[self.current_file_index]
        self.ax.set_title(f'Annotating {os.path.basename(current_path)}')
        # Redraw this figure specifically; plt.draw() would target pyplot's current figure.
        self.fig.canvas.draw_idle()

    def load_next_file(self):
        """Load the next time-series file for annotation."""
        if self.current_file_index < len(self.file_list):
            self.data = np.loadtxt(self.file_list[self.current_file_index])
            self._show_current_trace()
        else:
            plt.close(self.fig)
            print('All files annotated.')

    def press_callback(self, event):
        """Record the x-coordinate (time) where the mouse button is pressed."""
        # Canvas events fire for every axes (including the button axes) and for clicks
        # outside any axes (xdata is None); only presses on the trace start an interval.
        if event.inaxes is not self.ax or event.xdata is None:
            return
        if self._pending_marker is not None:
            # A previous press never received its release; drop its stale marker.
            self._pending_marker.remove()
        self._pending_start = event.xdata
        self._pending_marker = self.ax.axvline(event.xdata, c='r', linestyle=':')
        self.fig.canvas.draw_idle()

    def release_callback(self, event):
        """Record the x-coordinate (time) where the mouse button is released."""
        pressed_at = self._pending_start
        marker = self._pending_marker
        if pressed_at is None:
            # Release without a matching press on the trace (e.g. a button click).
            return
        self._pending_start = None
        self._pending_marker = None

        if event.inaxes is not self.ax or event.xdata is None:
            # The drag left the trace axes: discard it so start/end lists stay paired.
            if marker is not None:
                marker.remove()
            self.fig.canvas.draw_idle()
            return

        # Store intervals as (start, end) even when the user dragged right-to-left.
        begin, end = sorted((pressed_at, event.xdata))
        self.ax.axvline(event.xdata, c='r', linestyle=':')
        self.ax.axvspan(begin, end, alpha=0.5, color='red')
        self.start_contraction.append(begin)
        self.end_contraction.append(end)
        self.fig.canvas.draw_idle()

    def save_and_load_next(self, event):
        """Save the current annotations and load the next time-series file."""
        if self.current_file_index < len(self.file_list):
            file_path = self.file_list[self.current_file_index]
            file_name = os.path.basename(file_path)
            if self.output_dir:
                os.makedirs(self.output_dir, exist_ok=True)
            # Output name is the input name up to its first dot plus '_contr.txt'.
            output_file = os.path.join(self.output_dir, f'{file_name.split(".")[0]}_contr.txt')
            # Shape (n_intervals, 2) after transposing; with no intervals only the
            # header line is written.
            start_end_contraction = np.asarray([self.start_contraction, self.end_contraction])
            np.savetxt(output_file, start_end_contraction.T, header='start end', comments='',
                       fmt='%f')
            print(f'Saved contractions for {file_name} to {output_file}')
        self.current_file_index += 1
        self.load_next_file()

    def reset_annotations(self, event):
        """Reset the current annotations for the current time-series file."""
        if self.current_file_index >= len(self.file_list):
            # Every file has been processed; there is no trace left to reset.
            return
        self._show_current_trace()
def split_trace(filepath, output_dir, chunk_size=512, p=1):
    """
    Split a time-series trace into smaller chunks and save them to separate files.

    Chunk ``i`` holds samples ``[i * chunk_size, (i + 1) * chunk_size)`` and is written to
    ``<output_dir>/<name>_<i>.txt``, where ``<name>`` is the input file name up to its first
    dot. The last chunk may be shorter than ``chunk_size``.

    Parameters
    ----------
    filepath : str
        Path to the input time-series data file.
    output_dir : str
        Directory where the output chunk files will be saved. Created if missing.
    chunk_size : int, optional
        The size of each chunk, by default 512. Must be at least 1.
    p : float, optional
        Probability of saving each chunk, by default 1 (save all chunks).

    Raises
    ------
    ValueError
        If `chunk_size` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be a positive integer, got {chunk_size!r}')

    # ndmin=1 keeps a file holding a single value as a length-1 array instead of a 0-d
    # scalar, which has no len() and cannot be sliced.
    data = np.loadtxt(filepath, ndmin=1)

    # exist_ok avoids the check-then-create race; an empty output_dir means the current
    # directory, which os.makedirs cannot be asked to create.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    stem = os.path.basename(filepath).split('.')[0]

    # One iteration per chunk start offset; this yields ceil(len(data) / chunk_size) chunks.
    for i, offset in enumerate(range(0, len(data), chunk_size)):
        # One random draw per chunk, so each chunk is kept independently with probability p.
        if random.random() < p:
            chunk = data[offset:offset + chunk_size]
            np.savetxt(os.path.join(output_dir, f'{stem}_{i}.txt'), chunk)