Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ lint.select = ["ALL"]
"docs/source/*.ipynb" = [
"E402", # Module level import not at top of cell in doc notebooks
]
"*" = [
"ANN401" # typing.Any type hint for *args and **kwargs
]


[tool.pytest.ini_options]
Expand Down
3 changes: 0 additions & 3 deletions src/osekit/audio_backend/audio_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ def read(
"""
_, frames, _ = self.info(path)

if stop is None:
stop = frames

if stop is None:
stop = frames

Expand Down
54 changes: 1 addition & 53 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@
from pathlib import Path
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
import pytest
import soundfile as sf
from pandas import Timestamp

from osekit import config
from osekit.audio_backend.soundfile_backend import SoundFileBackend
from osekit.config import (
TIMESTAMP_FORMAT_EXPORTED_FILES_LOCALIZED,
TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED,
)
from osekit.core.audio_data import AudioData
from osekit.core.audio_file import AudioFile
from osekit.utils.audio import generate_sample_audio

Expand Down Expand Up @@ -110,7 +107,7 @@ def patch_filehandlers(
if "allow_log_write_to_file" in request.keywords:
return

def disabled_filewrite(self: any, record: any) -> None:
def disabled_filewrite(self: typing.Any, record: typing.Any) -> None:
"""Prevent the logger from actually writing files."""

monkeypatch.setattr(logging.FileHandler, "emit", disabled_filewrite)
Expand Down Expand Up @@ -167,55 +164,6 @@ def mock_open(self: SoundFileBackend, path: Path) -> None:
return opened_files


@pytest.fixture
def patch_audio_data(monkeypatch: pytest.MonkeyPatch) -> None:
original_init = AudioData.__init__
original_get_raw_value = AudioData.get_raw_value
original_length = AudioData.length

def mocked_init(
self: AudioData,
*args: list,
mocked_value: list[float] | np.ndarray | None = None,
**kwargs: dict,
) -> None:
defaults = {
"begin": Timestamp("2000-01-01 00:00:00"),
"end": Timestamp("2000-01-01 00:00:01"),
"sample_rate": 48000,
}
for key, value in defaults.items():
if key not in kwargs:
kwargs.update(**{key: value})

original_init(self, *args, **kwargs)
if mocked_value is not None:
self.mocked_value = mocked_value
if type(mocked_value) is list or len(mocked_value.shape) == 1:
self.mocked_value = np.array(self.mocked_value).reshape(
len(mocked_value),
1,
)

def mocked_length(self: AudioData) -> int:
if hasattr(self, "mocked_value"):
return len(self.mocked_value)
return original_length.fget(self)

def mocked_get_raw_value(self: AudioData) -> np.ndarray:
if hasattr(self, "mocked_value"):
return self.mocked_value
return original_get_raw_value(self)

monkeypatch.setattr(AudioData, "__init__", mocked_init)
monkeypatch.setattr(AudioData, "length", property(mocked_length))
monkeypatch.setattr(
AudioData,
"get_raw_value",
mocked_get_raw_value,
)


@pytest.fixture(autouse=True)
def restore_config() -> typing.Generator:
resample_quality_settings = {**config.resample_quality_settings}
Expand Down
39 changes: 39 additions & 0 deletions tests/helpers/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import typing

import numpy as np
from pandas import Timestamp

from osekit.core.audio_data import AudioData


class MockedAudioData(AudioData):
def __init__(
self,
mocked_value: list[float] | np.ndarray,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
defaults = {
"begin": Timestamp("2000-01-01 00:00:00"),
"end": Timestamp("2000-01-01 00:00:01"),
"sample_rate": 48000,
}
for key, value in defaults.items():
if key not in kwargs:
kwargs.update(**{key: value})

super().__init__(*args, **kwargs)
if mocked_value is not None:
self.mocked_value = mocked_value
if type(mocked_value) is list or len(mocked_value.shape) == 1:
self.mocked_value = np.array(self.mocked_value).reshape(
len(mocked_value),
1,
)

@property
def length(self) -> int:
return len(self.mocked_value)

def get_raw_value(self) -> np.ndarray:
return self.mocked_value
30 changes: 18 additions & 12 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@
from osekit.core.instrument import Instrument
from osekit.utils import audio
from osekit.utils.audio import Normalization, generate_sample_audio, normalize
from tests.helpers.audio import MockedAudioData


def test_patch_audio_data(patch_audio_data: None) -> None:
def test_mocked_audio_data() -> None:
mocked_value = [1.0, 2.0, 3.0]
audio_data = AudioData(
mocked_value=mocked_value, # Type: ignore # Unexpected argument
audio_data = MockedAudioData(
mocked_value=mocked_value,
)
assert np.array_equal(
audio_data.mocked_value[:, 0], # Type: ignore # Unresolved attribute
audio_data.mocked_value[:, 0],
mocked_value,
)

Expand Down Expand Up @@ -1534,10 +1535,16 @@ def test_audio_dataset_from_folder_errors_warnings(
assert all(f in caplog.text for f in corrupted_audio_files)


def test_audio_dataset_instrument(patch_audio_data: None) -> None:
def test_audio_dataset_instrument() -> None:
ad = [
AudioData(mocked_value=[1, 2, 3], instrument=Instrument(end_to_end_db=150.0)),
AudioData(mocked_value=[4, 5, 6], instrument=Instrument(end_to_end_db=150.0)),
MockedAudioData(
mocked_value=[1, 2, 3],
instrument=Instrument(end_to_end_db=150.0),
),
MockedAudioData(
mocked_value=[4, 5, 6],
instrument=Instrument(end_to_end_db=150.0),
),
]

ads = AudioDataset(data=ad)
Expand Down Expand Up @@ -1807,9 +1814,8 @@ def test_split_data(
assert subsubdata.normalization == subdata.normalization


def test_split_data_normalization_pass(patch_audio_data: None) -> None:
ad = AudioData()
ad.mocked_value = [1, 2, 3]
def test_split_data_normalization_pass() -> None:
ad = MockedAudioData(mocked_value=[1, 2, 3])
original_normalization_values = ad.get_normalization_values()
assert all(original_normalization_values.values())

Expand Down Expand Up @@ -1926,9 +1932,9 @@ def test_split_data_frames(
assert np.array_equal(ad.get_value()[:, 0], expected_data)


def test_split_frames_errors(patch_audio_data: None) -> None:
def test_split_frames_errors() -> None:
mocked_value = [1, 2, 3]
ad = AudioData(mocked_value=mocked_value)
ad = MockedAudioData(mocked_value=mocked_value)
error_msgs = [
"Start_frame must be greater than or equal to 0.",
"Stop_frame must be lower than the length of the data.",
Expand Down
Loading