Skip to content
48 changes: 45 additions & 3 deletions src/osekit/core/audio_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from osekit.core.audio_item import AudioItem
from osekit.core.base_data import BaseData
from osekit.core.instrument import Instrument
from osekit.utils.audio import Normalization, normalize
from osekit.utils.audio import Butterworth, Normalization, normalize

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -45,6 +45,7 @@ def __init__(
instrument: Instrument | None = None,
normalization: Normalization = Normalization.RAW,
normalization_values: dict | None = None,
butter: Butterworth | None = None,
) -> None:
"""Initialize an ``AudioData`` from a list of ``AudioItems``.

Expand All @@ -67,13 +68,16 @@ def __init__(
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.
butter: Butterworth | None
Butterworth filter to apply to the audio data.

"""
super().__init__(items=items, begin=begin, end=end, name=name)
self._set_sample_rate(sample_rate=sample_rate)
self.instrument = instrument
self.normalization = normalization
self.normalization_values = normalization_values
self.butter = butter

@property
def nb_channels(self) -> int:
Expand Down Expand Up @@ -123,6 +127,15 @@ def normalization_values(self, value: dict | None) -> None:
}
)

@property
def butter(self) -> Butterworth:
"""The Butterworth filter to apply to the audio data."""
return self._butter

@butter.setter
def butter(self, value: Butterworth) -> None:
self._butter = value

@classmethod
def _make_item(
cls,
Expand Down Expand Up @@ -178,7 +191,7 @@ def get_normalization_values(self) -> dict:
"std": standard deviation used for z-score normalization

"""
values = np.array(self.get_raw_value())
values = np.array(self.get_filtered_value())
self.normalization_values = {
"mean": values.mean(),
"peak": values.max(),
Expand Down Expand Up @@ -222,6 +235,22 @@ def get_raw_value(self) -> np.ndarray:
"""
return np.vstack(list(self.stream()))

def get_filtered_value(self) -> np.ndarray:
"""Return the value of the audio data after filtering.

Returns
-------
np.ndarray:
The value of the audio data filtered by the ``self.butter`` Butterworth filter.

"""
output = self.get_raw_value()
return (
output
if self.butter is None
else self.butter.filter(sig=output, fs=self.sample_rate)
)

@staticmethod
def _flush(
resampler: soxr.ResampleStream,
Expand Down Expand Up @@ -320,7 +349,7 @@ def get_value(self) -> np.ndarray:

"""
return normalize(
values=self.get_raw_value(),
values=self.get_filtered_value(),
normalization=self.normalization,
**self.normalization_values,
)
Expand Down Expand Up @@ -547,9 +576,13 @@ def to_dict(self) -> dict:
None if self.instrument is None else self.instrument.to_dict()
),
}
butter_dict = {
"butter": (None if self.butter is None else self.butter.to_dict()),
}
return (
base_dict
| instrument_dict
| butter_dict
| {
"sample_rate": self.sample_rate,
"normalization": self.normalization.value,
Expand Down Expand Up @@ -595,6 +628,11 @@ def _from_base_dict(
if dictionary["instrument"] is None
else Instrument.from_dict(dictionary["instrument"])
)
butter = (
None
if "butter" not in dictionary or dictionary["butter"] is None
else Butterworth.from_dict(dictionary["butter"])
)
return cls.from_files(
files=files,
begin=begin,
Expand All @@ -603,6 +641,7 @@ def _from_base_dict(
sample_rate=dictionary["sample_rate"],
normalization=Normalization(dictionary["normalization"]),
normalization_values=dictionary["normalization_values"],
butter=butter,
)

@classmethod
Expand Down Expand Up @@ -641,6 +680,9 @@ def from_files(
normalization: Normalization
The type of normalization to apply to the audio data.

butter: Butterworth
Butterworth filter to apply to the audio data.

Returns
-------
Self:
Expand Down
21 changes: 20 additions & 1 deletion src/osekit/core/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from osekit.core.audio_file import AudioFile
from osekit.core.base_dataset import BaseDataset
from osekit.core.json_serializer import deserialize_json
from osekit.utils.audio import Normalization
from osekit.utils.audio import Butterworth, Normalization
from osekit.utils.multiprocess import multiprocess

if TYPE_CHECKING:
Expand Down Expand Up @@ -89,6 +89,17 @@ def normalization(self, normalization: Normalization) -> None:
for data in self.data:
data.normalization = normalization

@property
def butter(self) -> Butterworth:
"""Return the most frequent Butterworth filter among those of this dataset data."""
butters = [data.butter for data in self.data]
return max(set(butters), key=butters.count)

@butter.setter
def butter(self, butter: Butterworth) -> None:
for data in self.data:
data.butter = butter

@property
def instrument(self) -> Instrument | None:
"""Instrument that can be used to get acoustic pressure from wav audio data."""
Expand Down Expand Up @@ -187,6 +198,7 @@ def from_folder( # noqa: PLR0913
name: str | None = None,
instrument: Instrument | None = None,
normalization: Normalization = Normalization.RAW,
butter: Butterworth | None = None,
**kwargs, # noqa: ANN003
) -> Self:
"""Return an ``AudioDataset`` from a folder containing the audio files.
Expand Down Expand Up @@ -240,6 +252,8 @@ def from_folder( # noqa: PLR0913
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.
butter: Butterworth | None
Butterworth filter to apply to the audio data.
kwargs: any
Keyword arguments passed to the ``BaseDataset.from_folder()`` classmethod.

Expand All @@ -262,6 +276,7 @@ def from_folder( # noqa: PLR0913
name=name,
instrument=instrument,
normalization=normalization,
butter=butter,
)

@classmethod
Expand All @@ -277,6 +292,7 @@ def from_files( # noqa: PLR0913
sample_rate: float | None = None,
instrument: Instrument | None = None,
normalization: Normalization = Normalization.RAW,
butter: Butterworth | None = None,
) -> AudioDataset:
"""Return an AudioDataset object from a list of AudioFiles.

Expand Down Expand Up @@ -317,6 +333,8 @@ def from_files( # noqa: PLR0913
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.
butter: Butterworth | None
Butterworth filter to apply to the audio data.

Returns
-------
Expand All @@ -335,6 +353,7 @@ def from_files( # noqa: PLR0913
mode=mode,
overlap=overlap,
data_duration=data_duration,
butter=butter,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions src/osekit/public/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def prepare_audio(self, transform: Transform) -> AudioDataset:
mode=transform.mode,
overlap=transform.overlap,
normalization=transform.normalization,
butter=transform.butter,
name=transform.name,
instrument=self.instrument,
)
Expand Down
6 changes: 5 additions & 1 deletion src/osekit/public/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Flag, auto
from typing import TYPE_CHECKING, Literal

from osekit.utils.audio import Normalization
from osekit.utils.audio import Butterworth, Normalization

if TYPE_CHECKING:
from pandas import Timedelta, Timestamp
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
overlap: float = 0.0,
sample_rate: float | None = None,
normalization: Normalization = Normalization.RAW,
butter: Butterworth | None = None,
name: str | None = None,
subtype: str | None = None,
fft: ShortTimeFFT | None = None,
Expand Down Expand Up @@ -118,6 +119,8 @@ def __init__(
will be set to the one of the original dataset.
normalization: Normalization
The type of normalization to apply to the audio data.
butter: Butterworth | None
Butterworth filter to apply to the audio data.
name: str | None
Name of the transform dataset.
Defaulted as the begin timestamp of the transform dataset.
Expand Down Expand Up @@ -160,6 +163,7 @@ def __init__(
self.sample_rate = sample_rate
self.name = name
self.normalization = normalization
self.butter = butter
self.subtype = subtype
self.v_lim = v_lim
self.colormap = colormap
Expand Down
101 changes: 101 additions & 0 deletions src/osekit/utils/audio.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import dataclasses
import enum
from collections.abc import Iterable
from typing import Literal, Self

import numpy as np
import soxr
from pandas import Timedelta
from scipy import signal

from osekit.config import (
resample_quality_settings,
Expand Down Expand Up @@ -203,3 +206,101 @@ def normalize(
if Normalization.ZSCORE in normalization:
values = normalize_zscore(values=values, mean=mean, std=std)
return values


@dataclasses.dataclass
class Butterworth:
"""Class that represent a Butterworth sos filter.

Parameters
----------
N: int
The order of the filter.
For "bandpass" and "bandstop" filters, the resulting order of the final
second-order sections ("sos") matrix is ``2*N``,
with ``N`` the number of biquad sections of the desired system.
Wn: Iterable | int | float
The critical frequency or frequencies.
For lowpass and highpass filters, ``Wn`` is a scalar.
For bandpass and bandstop filters, ``Wn`` is a length-2 sequence.
For a Butterworth filter, this is the point at which the gain
drops to ``1/sqrt(2)`` that of the passband (the “-3 dB point”).
For digital filters, if ``fs`` is not specified,
``Wn`` units are normalized from ``0`` to ``1``,
where ``1`` is the Nyquist frequency
(``Wn`` is thus in half cycles / sample and defined as
``2*critical frequencies / fs``).
If ``fs`` is specified, ``Wn`` is in the same units as ``fs``.
For analog filters, ``Wn`` is an angular frequency (e.g. ``rad/s``).
btype: Literal["lowpass", "highpass", "bandpass", "bandstop"]
The type of filter. Default is "lowpass".

"""

N: int
Wn: Iterable | int | float
btype: Literal["lowpass", "highpass", "bandpass", "bandstop"] = "lowpass"

def to_dict(self) -> dict:
"""Serialize a Butterworth sos filter to a dictionary.

Returns
-------
dict:
Serialized Butterworth sos filter.

"""
return {
"N": self.N,
"Wn": self.Wn,
"btype": self.btype,
}

@classmethod
def from_dict(cls, data: dict) -> Butterworth:
"""Deserialize a Butterworth sos filter from a dictionary.

Parameters
----------
data: dict
Serialized Butterworth sos filter.

Returns
-------
Butterworth:
The Butterworth sos filter.

"""
return cls(
N=data["N"],
Wn=data["Wn"],
btype=data["btype"],
)

def filter(self, sig: np.typing.NDArray, fs: float) -> np.typing.NDArray:
"""Filter an input signal with the Butterworth sos filter.

Parameters
----------
sig: np.typing.NDArray
Input signal
fs: float
Sampling frequency of the signal

Returns
-------
np.typing.NDArray
Filtered signal

"""
sos = signal.butter(
N=self.N,
Wn=self.Wn,
btype=self.btype,
fs=fs,
output="sos",
)
return signal.sosfilt(sos=sos, x=sig, axis=0)

def __hash__(self) -> int:
return hash((self.N, self.Wn, self.btype))
Loading