Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 331 additions & 0 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@
"""Tests for collections"""

import copy
import json
import os
import shutil
import tempfile
import unittest
from pathlib import Path

import jsonschema
import numpy as np
from hypothesis import given
from hypothesis import strategies as st

from sigmf import schema
from sigmf.archive import SIGMF_COLLECTION_EXT, SIGMF_DATASET_EXT, SIGMF_METADATA_EXT
from sigmf.error import SigMFFileError, SigMFFileExistsError
from sigmf.sigmffile import SigMFCollection, SigMFFile, fromfile

from .testdata import TEST_FLOAT32_DATA, TEST_METADATA
Expand Down Expand Up @@ -74,3 +78,330 @@ def test_load_collection(self, subdir: str) -> None:

self.assertTrue(np.array_equal(TEST_FLOAT32_DATA, meta1_loopback.read_samples()))
self.assertTrue(np.array_equal(TEST_FLOAT32_DATA, meta2_loopback[:]))


class TestCollectionConstructor(unittest.TestCase):
"""tests for SigMFCollection constructor"""

def test_empty_constructor(self):
"""test that a collection can be created with no arguments"""
collection = SigMFCollection(skip_checksums=True)
self.assertIsInstance(collection, SigMFCollection)
self.assertEqual(len(collection), 0)
self.assertEqual(collection.get_stream_names(), [])

def test_constructor_with_metadata(self):
"""test that a collection can be created with a metadata dict"""
from sigmf import __specification__

metadata = {
SigMFCollection.COLLECTION_KEY: {
SigMFCollection.VERSION_KEY: __specification__,
SigMFCollection.STREAMS_KEY: [],
}
}
collection = SigMFCollection(metadata=metadata, skip_checksums=True)
self.assertIsInstance(collection, SigMFCollection)
self.assertEqual(len(collection), 0)


class TestCollectionRoundTrip(unittest.TestCase):
"""tests for SigMFCollection round-trip write/read"""

def setUp(self):
"""create temporary directory and populate with SigMF files"""
self.temp_dir = Path(tempfile.mkdtemp())
# create two SigMF recordings
meta_name1 = "stream0" + SIGMF_METADATA_EXT
meta_name2 = "stream1" + SIGMF_METADATA_EXT
data_path1 = self.temp_dir / ("stream0" + SIGMF_DATASET_EXT)
data_path2 = self.temp_dir / ("stream1" + SIGMF_DATASET_EXT)
TEST_FLOAT32_DATA.tofile(data_path1)
TEST_FLOAT32_DATA.tofile(data_path2)
meta1 = SigMFFile(metadata=copy.deepcopy(TEST_METADATA), data_file=data_path1)
meta2 = SigMFFile(metadata=copy.deepcopy(TEST_METADATA), data_file=data_path2)
meta1.tofile(self.temp_dir / meta_name1, overwrite=True)
meta2.tofile(self.temp_dir / meta_name2, overwrite=True)
self.meta_name1 = meta_name1
self.meta_name2 = meta_name2
self.collection_path = self.temp_dir / ("mycollection" + SIGMF_COLLECTION_EXT)

def tearDown(self):
"""remove temporary directory"""
shutil.rmtree(self.temp_dir)

def test_round_trip_metadata(self):
"""test that collection metadata survives a write/read round-trip"""
collection = SigMFCollection(
metafiles=[self.meta_name1, self.meta_name2],
base_path=str(self.temp_dir),
)
collection.set_collection_field(SigMFCollection.AUTHOR_KEY, "Round Trip Tester")
collection.set_collection_field(SigMFCollection.DESCRIPTION_KEY, "A round-trip test collection")
collection.set_collection_field(SigMFCollection.LICENSE_KEY, "https://creativecommons.org/licenses/by-sa/4.0/")

collection.tofile(self.collection_path)

# read back
collection_rt = fromfile(self.collection_path)

self.assertIsInstance(collection_rt, SigMFCollection)
self.assertEqual(len(collection_rt), 2)
self.assertEqual(collection_rt.get_stream_names(), ["stream0", "stream1"])
self.assertEqual(collection_rt.get_collection_field(SigMFCollection.AUTHOR_KEY), "Round Trip Tester")
self.assertEqual(
collection_rt.get_collection_field(SigMFCollection.DESCRIPTION_KEY), "A round-trip test collection"
)
self.assertEqual(
collection_rt.get_collection_field(SigMFCollection.LICENSE_KEY),
"https://creativecommons.org/licenses/by-sa/4.0/",
)

def test_round_trip_collection_info(self):
"""test that get_collection_info returns a dict matching what was set"""
collection = SigMFCollection(
metafiles=[self.meta_name1, self.meta_name2],
base_path=str(self.temp_dir),
)
collection.set_collection_field(SigMFCollection.AUTHOR_KEY, "Test Author")
collection.tofile(self.collection_path)

collection_rt = fromfile(self.collection_path)
info = collection_rt.get_collection_info()
self.assertIsInstance(info, dict)
self.assertIn(SigMFCollection.AUTHOR_KEY, info)
self.assertEqual(info[SigMFCollection.AUTHOR_KEY], "Test Author")
self.assertIn(SigMFCollection.VERSION_KEY, info)
self.assertIn(SigMFCollection.STREAMS_KEY, info)

def test_round_trip_json_content(self):
"""test that the written collection file is valid JSON with expected structure"""
collection = SigMFCollection(
metafiles=[self.meta_name1, self.meta_name2],
base_path=str(self.temp_dir),
)
collection.tofile(self.collection_path)

with open(self.collection_path, "r") as f:
data = json.load(f)

self.assertIn(SigMFCollection.COLLECTION_KEY, data)
self.assertIn(SigMFCollection.STREAMS_KEY, data[SigMFCollection.COLLECTION_KEY])
self.assertIn(SigMFCollection.VERSION_KEY, data[SigMFCollection.COLLECTION_KEY])
streams = data[SigMFCollection.COLLECTION_KEY][SigMFCollection.STREAMS_KEY]
self.assertEqual(len(streams), 2)
for stream in streams:
self.assertIn("name", stream)
self.assertIn("hash", stream)


class TestCollectionValidation(unittest.TestCase):
"""tests for SigMFCollection validation against the JSON schema"""

def _validate(self, metadata):
"""helper: validate collection metadata against the collection schema"""
col_schema = schema.get_schema(schema_file=schema.SCHEMA_COLLECTION)
jsonschema.validators.validate(instance=metadata, schema=col_schema)

def test_valid_empty_collection(self):
"""a minimal collection with only core:version should be schema-valid"""
collection = SigMFCollection(skip_checksums=True)
self._validate(collection._metadata)

def test_valid_collection_with_optional_fields(self):
"""a collection with optional fields set should be schema-valid"""
collection = SigMFCollection(skip_checksums=True)
collection.set_collection_field(SigMFCollection.AUTHOR_KEY, "Test Author")
collection.set_collection_field(SigMFCollection.DESCRIPTION_KEY, "Test description")
collection.set_collection_field(SigMFCollection.LICENSE_KEY, "https://example.com/license")
collection.set_collection_field(SigMFCollection.COLLECTION_DOI_KEY, "10.1000/xyz123")
self._validate(collection._metadata)

def test_invalid_collection_missing_version(self):
"""a collection missing core:version should fail schema validation"""
metadata = {SigMFCollection.COLLECTION_KEY: {}}
col_schema = schema.get_schema(schema_file=schema.SCHEMA_COLLECTION)
with self.assertRaises(jsonschema.exceptions.ValidationError):
jsonschema.validators.validate(instance=metadata, schema=col_schema)

def test_invalid_collection_missing_collection_key(self):
"""a metadata dict without the top-level 'collection' key should fail"""
metadata = {}
col_schema = schema.get_schema(schema_file=schema.SCHEMA_COLLECTION)
with self.assertRaises(jsonschema.exceptions.ValidationError):
jsonschema.validators.validate(instance=metadata, schema=col_schema)

def test_valid_collection_with_extensions(self):
"""a collection with a valid extensions array should be schema-valid"""
collection = SigMFCollection(skip_checksums=True)
collection.set_collection_field(
SigMFCollection.EXTENSIONS_KEY,
[{"name": "antenna", "version": "1.0.0", "optional": True}],
)
self._validate(collection._metadata)


class TestCollectionCommonUseCases(unittest.TestCase):
"""tests for common SigMFCollection use cases"""

def setUp(self):
"""create temporary directory and two SigMF recordings"""
self.temp_dir = Path(tempfile.mkdtemp())
for name in ("rec0", "rec1", "rec2"):
data_path = self.temp_dir / (name + SIGMF_DATASET_EXT)
meta_path = self.temp_dir / (name + SIGMF_METADATA_EXT)
TEST_FLOAT32_DATA.tofile(data_path)
meta = SigMFFile(metadata=copy.deepcopy(TEST_METADATA), data_file=data_path)
meta.tofile(meta_path, overwrite=True)
self.metafiles = [f"{name}{SIGMF_METADATA_EXT}" for name in ("rec0", "rec1", "rec2")]

def tearDown(self):
shutil.rmtree(self.temp_dir)

def _make_collection(self, metafiles=None):
"""helper: create a SigMFCollection using the temp dir"""
if metafiles is None:
metafiles = self.metafiles
return SigMFCollection(metafiles=metafiles, base_path=str(self.temp_dir))

def test_len(self):
"""__len__ should return the number of streams"""
collection = self._make_collection()
self.assertEqual(len(collection), 3)

def test_len_empty(self):
"""an empty collection should have length 0"""
collection = SigMFCollection(skip_checksums=True)
self.assertEqual(len(collection), 0)

def test_get_stream_names(self):
"""get_stream_names should return base names in order"""
collection = self._make_collection()
names = collection.get_stream_names()
self.assertEqual(names, ["rec0", "rec1", "rec2"])

def test_get_sigmffile_by_index(self):
"""get_SigMFFile with stream_index should return correct SigMFFile"""
collection = self._make_collection()
sf = collection.get_SigMFFile(stream_index=0)
self.assertIsInstance(sf, SigMFFile)
self.assertTrue(np.array_equal(TEST_FLOAT32_DATA, sf.read_samples()))

def test_get_sigmffile_by_name(self):
"""get_SigMFFile with stream_name should return correct SigMFFile"""
collection = self._make_collection()
sf = collection.get_SigMFFile(stream_name="rec1")
self.assertIsInstance(sf, SigMFFile)
self.assertTrue(np.array_equal(TEST_FLOAT32_DATA, sf.read_samples()))

def test_get_sigmffile_invalid_name(self):
"""get_SigMFFile with an unknown stream_name should return None"""
collection = self._make_collection()
result = collection.get_SigMFFile(stream_name="nonexistent")
self.assertIsNone(result)

def test_set_get_collection_field(self):
"""set_collection_field and get_collection_field should round-trip values"""
collection = self._make_collection()
collection.set_collection_field(SigMFCollection.AUTHOR_KEY, "Jane Doe")
self.assertEqual(collection.get_collection_field(SigMFCollection.AUTHOR_KEY), "Jane Doe")

def test_get_collection_field_default(self):
"""get_collection_field should return default when key is absent"""
collection = self._make_collection()
result = collection.get_collection_field("core:nonexistent_key", default="fallback")
self.assertEqual(result, "fallback")

def test_set_get_collection_info(self):
"""set_collection_info and get_collection_info should round-trip a dict"""
from sigmf import __specification__

collection = self._make_collection()
new_info = {
SigMFCollection.VERSION_KEY: __specification__,
SigMFCollection.AUTHOR_KEY: "Info Author",
SigMFCollection.STREAMS_KEY: collection.get_collection_field(SigMFCollection.STREAMS_KEY),
}
collection.set_collection_info(new_info)
info = collection.get_collection_info()
self.assertEqual(info[SigMFCollection.AUTHOR_KEY], "Info Author")

def test_overwrite_protection(self):
"""writing a collection to an existing file without overwrite=True should raise"""
collection_path = self.temp_dir / ("test" + SIGMF_COLLECTION_EXT)
collection = self._make_collection()
collection.tofile(collection_path)
with self.assertRaises(SigMFFileExistsError):
collection.tofile(collection_path)

def test_overwrite_allowed(self):
"""writing with overwrite=True should succeed even if file exists"""
collection_path = self.temp_dir / ("test" + SIGMF_COLLECTION_EXT)
collection = self._make_collection()
collection.tofile(collection_path)
collection.tofile(collection_path, overwrite=True)
self.assertTrue(collection_path.exists())

def test_skip_checksums(self):
"""skip_checksums=True should allow creating a collection without verifying hashes"""
collection_path = self.temp_dir / ("test" + SIGMF_COLLECTION_EXT)
collection = self._make_collection()
collection.tofile(collection_path)
# test via SigMFCollection constructor with skip_checksums=True
with open(collection_path, "r") as f:
metadata = json.load(f)
collection_loaded = SigMFCollection(metadata=metadata, base_path=str(self.temp_dir), skip_checksums=True)
self.assertIsInstance(collection_loaded, SigMFCollection)
self.assertEqual(len(collection_loaded), 3)

def test_verify_stream_hashes_valid(self):
"""verify_stream_hashes should not raise when hashes are correct"""
collection = self._make_collection()
# should not raise
collection.verify_stream_hashes()

def test_verify_stream_hashes_invalid(self):
"""verify_stream_hashes should raise when a stream hash is wrong"""
collection = self._make_collection()
# corrupt the hash of the first stream
streams = collection.get_collection_field(SigMFCollection.STREAMS_KEY)
streams[0]["hash"] = "badhash"
collection.set_collection_field(SigMFCollection.STREAMS_KEY, streams)
with self.assertRaises(SigMFFileError):
collection.verify_stream_hashes()

def test_error_on_nonexistent_metafile(self):
"""constructing a collection with a non-existent file should raise SigMFFileError"""
with self.assertRaises(SigMFFileError):
SigMFCollection(
metafiles=["does_not_exist" + SIGMF_METADATA_EXT],
base_path=str(self.temp_dir),
)

def test_error_on_non_meta_extension(self):
"""constructing a collection with a file lacking .sigmf-meta extension should raise"""
with self.assertRaises(SigMFFileError):
SigMFCollection(
metafiles=["rec0" + SIGMF_DATASET_EXT],
base_path=str(self.temp_dir),
)

def test_set_streams_updates_hashes(self):
"""set_streams should recompute hashes for the specified metafiles"""
collection = self._make_collection(metafiles=["rec0" + SIGMF_METADATA_EXT])
self.assertEqual(len(collection), 1)
# add more streams
collection.set_streams(["rec0" + SIGMF_METADATA_EXT, "rec1" + SIGMF_METADATA_EXT])
self.assertEqual(len(collection), 2)
names = collection.get_stream_names()
self.assertIn("rec0", names)
self.assertIn("rec1", names)

def test_collection_dumps_is_valid_json(self):
"""dumps() should produce valid JSON containing collection data"""
collection = self._make_collection()
s = collection.dumps()
data = json.loads(s)
self.assertIn(SigMFCollection.COLLECTION_KEY, data)
self.assertIn(SigMFCollection.STREAMS_KEY, data[SigMFCollection.COLLECTION_KEY])
self.assertEqual(len(data[SigMFCollection.COLLECTION_KEY][SigMFCollection.STREAMS_KEY]), 3)