diff --git a/pyproject.toml b/pyproject.toml index 2cc00bd9..58aaa758 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "distro>=1.8.0", "email-validator>=1.1", "imas-python", + "netCDF4>=1.5", "numpy>=1.14", "pydantic>=2.10.6", "python-dateutil>=2.6", @@ -49,7 +50,6 @@ dependencies = [ "semantic-version>=2.8", "sqlalchemy>=1.2.12,<2.0", "alembic~=1.13", - "urllib3>=1.26", "rich>=14.3.3", ] diff --git a/src/simdb/checksum.py b/src/simdb/checksum.py index 3d1bf09c..99aef230 100644 --- a/src/simdb/checksum.py +++ b/src/simdb/checksum.py @@ -1,9 +1,10 @@ import hashlib +from pathlib import Path -from .uri import URI +from simdb.imas.utils import SimDBUrl -def sha1_checksum(uri: URI) -> str: +def sha1_checksum(uri: SimDBUrl) -> str: """Generate a SHA1 checksum from the given file. :param uri: the URI of the file to checksum @@ -13,7 +14,7 @@ def sha1_checksum(uri: URI) -> str: raise ValueError(f"invalid scheme for file checksum: {uri.scheme}") if uri.path is None: raise ValueError("Path is not set") - path = uri.path + path = Path(uri.path) if not path.exists(): raise ValueError("File does not exist") diff --git a/src/simdb/cli/commands/manifest.py b/src/simdb/cli/commands/manifest.py index 954fa5f6..aa007041 100644 --- a/src/simdb/cli/commands/manifest.py +++ b/src/simdb/cli/commands/manifest.py @@ -2,7 +2,7 @@ import click -from simdb.cli.manifest import InvalidManifest, Manifest +from simdb.cli.manifest import Manifest @click.group() @@ -16,13 +16,8 @@ def manifest(): def check(file_name): """Check manifest FILE_NAME.""" - manifest = Manifest() - manifest.load(file_name) - try: - manifest.validate() - click.echo("ok") - except InvalidManifest as err: - click.echo(err, err=True) + Manifest.load_from_file(file_name) + click.echo("ok") @manifest.command() diff --git a/src/simdb/cli/commands/simulation.py b/src/simdb/cli/commands/simulation.py index 7f442ce4..af5545c5 100644 --- a/src/simdb/cli/commands/simulation.py +++ b/src/simdb/cli/commands/simulation.py @@ -7,7 +7,7 @@ import click -from simdb.cli.manifest import InvalidAlias, Manifest +from simdb.cli.manifest import Manifest from simdb.cli.remote_api import RemoteAPI, RemoteError from simdb.config.config import Config from simdb.database import DatabaseError, get_local_db @@ -148,13 +148,11 @@ def simulation_info(config: Config, sim_id: str): def simulation_ingest(config: Config, manifest_file: str, alias: str): """Ingest a MANIFEST_FILE.""" - manifest = Manifest() - manifest.load(Path(manifest_file)) - try: - manifest.validate() - except InvalidAlias: - if not alias: - raise + overrides = {} + if alias: + overrides["alias"] = alias + + manifest = Manifest.load_from_file(Path(manifest_file), overrides=overrides) simulation = Simulation(manifest, config) if alias: diff --git a/src/simdb/cli/manifest.py b/src/simdb/cli/manifest.py index 77ba7d05..7047b25a 100644 --- a/src/simdb/cli/manifest.py +++ b/src/simdb/cli/manifest.py @@ -1,596 +1,306 @@ import os -import re import urllib.parse +import warnings from enum import Enum, auto from pathlib import Path -from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Type, Union +from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, TextIO +from uuid import UUID import numpy as np import yaml +from netCDF4 import Dataset +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + UrlConstraints, + field_validator, + model_validator, +) -from simdb.uri import URI - - -class InvalidManifest(Exception): - """Exception to throw when a manifest fails to validate.""" - - pass - - -class InvalidAlias(InvalidManifest): - """Exception to throw when the alias specified in the manifest is invalid.""" - - pass +from simdb.imas.utils import SimDBUrl def _expand_path(path: Path, base_path: Path) -> Path: os.environ["MANIFEST_DIR"] = str(base_path) path = Path(os.path.expandvars(str(path))).expanduser() - path = Path(str(path).replace("//", "/")) if not path.is_absolute(): if not base_path.is_absolute(): raise ValueError("base_path must be absolute") return base_path / path else: - # Expand any /./ and /../ in absolute path path = path.resolve() return path -def _to_uri(uri_str: str, base_path: Path) -> Tuple["DataObject.Type", "URI"]: - uri = URI(uri_str) - if uri.authority: - raise InvalidManifest(f"invalid uri: {uri_str} - path must be absolute") - if uri.scheme is None: - raise InvalidManifest(f"invalid uri: {uri_str} - no scheme provided") - if uri.scheme == "file": - if uri.path is None: - raise InvalidManifest(f"invalid uri: {uri_str} - no path provided") - uri = URI(uri, path=_expand_path(uri.path, base_path)) - return DataObject.Type.FILE, uri - if uri.scheme == "imas": - if "path" not in uri.query and not all( - ("shot" in uri.query, "run" in uri.query, "database" in uri.query) - ): - raise InvalidManifest( - f"invalid uri: {uri_str} - no path or (shot, run, database) provided " - "in IMAS uri" - ) - return DataObject.Type.IMAS, uri - if uri.scheme == "simdb": - return DataObject.Type.UUID, uri - raise InvalidManifest(f"invalid uri: {uri_str}") - - -class DataObject: - """ - Simulation data object, either a file, an IDS or an already registered object - identifiable by the UUID. +ManifestUrl = Annotated[ + SimDBUrl, UrlConstraints(allowed_schemes=["file", "imas", "simdb"]) +] - PATH: file:/// - IMAS: imas:?path= - """ - class Type(Enum): - UNKNOWN = auto() - UUID = auto() - FILE = auto() - IMAS = auto() +class DataType(Enum): + UNKNOWN = auto() + UUID = auto() + FILE = auto() + IMAS = auto() - type: Type = Type.UNKNOWN - uri: Union[URI, None] = None - def __init__(self, base_path: Path, uri: str) -> None: - (self.type, self.uri) = _to_uri(uri, base_path) - if self.type == DataObject.Type.UNKNOWN or not self.uri: - raise InvalidManifest("invalid input") +def _get_data_object_type(uri: SimDBUrl) -> "DataType": + if uri.scheme == "imas": + return DataType.IMAS + elif uri.scheme == "file": + if uri.path is None: + raise ValueError("no path provided") + if Path(uri.path).suffix == ".nc": + with Dataset(uri.path, "r") as ds: + if getattr(ds, "Conventions", None) == "IMAS": + return DataType.IMAS + return DataType.FILE + elif uri.scheme == "simdb": + return DataType.UUID - @property - def name(self) -> str: - return str(self.uri) + raise ValueError(f"URI scheme ({uri.scheme}:) not recognized") -class Source(DataObject): - """ - Simulation data inputs. - """ +class DataObject(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) - pass + uri: ManifestUrl = Field() + _type: DataType = PrivateAttr(default=DataType.UNKNOWN) -class Sink(DataObject): - """ - Simulation data outputs. - """ + @model_validator(mode="after") + def _resolve_type(self) -> "DataObject": + self._type = _get_data_object_type(self.uri) + return self - pass - - -class ManifestValidator: - """ - Base class for validation of manifests. - """ - - version: int - - def __init__(self, version: int): - self.version = version - - def validate(self, values: Union[List, Dict]) -> None: - pass - - -class ListValuesValidator(ManifestValidator): - """ - Class for the validation of list items in the manifest. - """ - - def __init__( - self, - version: int, - section_name: Optional[str] = None, - expected_keys: Optional[Iterable] = None, - required_keys: Optional[Iterable] = None, - ) -> None: - self.section_name: Optional[str] = section_name - self.expected_keys: Optional[Iterable] = expected_keys - self.required_keys: Optional[Iterable] = required_keys - super().__init__(version) - - def validate(self, values: Union[list, dict]) -> None: - if values is None: - return - if isinstance(values, dict): - raise InvalidManifest( - f"badly formatted manifest - {self.section_name} should be provided as " - "a list" - ) - for item in values: - if not isinstance(item, dict) or len(item) > 1: - raise InvalidManifest( - f"badly formatted manifest - {self.section_name} values should be " - "a name value pair" - ) - name = next(iter(item)) + @property + def name(self) -> str: + return self.uri.encoded_string() - if isinstance(self.required_keys, tuple) and name not in self.required_keys: - raise InvalidManifest( - f"required {self.section_name} key not found in manifest: {name}" + @field_validator("uri", mode="after") + @classmethod + def validate_uri(cls, v: ManifestUrl, info): + context = info.context or {} + base_path = context.get("base_path") + if not base_path: + base_path = Path.cwd() + + if v.path is None: + raise ValueError("no uri path provided") + + if v.scheme == "imas": + qs = dict(v.query_params()) + if "path" not in qs and ( + "shot" not in qs or "run" not in qs or "database" not in qs + ): + raise ValueError( + "no path or (shot, run, database) provided in IMAS uri" ) - -class DictValuesValidator(ManifestValidator): - """ - Class for the validation of dictionary items in the manifest. - """ - - def __init__( - self, - version: int, - section_name: Optional[str] = None, - expected_keys: Optional[Iterable] = None, - required_keys: Optional[Iterable] = None, - ) -> None: - self.section_name: Optional[str] = section_name - self.expected_keys: Optional[Iterable] = expected_keys - self.required_keys: Optional[Iterable] = required_keys - super().__init__(version) - - def validate(self, values: Union[list, dict]) -> None: - if isinstance(values, list): - raise InvalidManifest( - f"badly formatted manifest - {self.section_name} should be provided as " - "a dict" + elif v.scheme == "file": + v = v.build( + scheme="file", + path=_expand_path(Path(v.path), base_path).as_posix(), ) - if self.expected_keys is not None: - for key in values: - if key not in self.expected_keys: - if re.match(r"code[0-9]+", key): - for code_key in values[key]: - if code_key not in ("name", "repo", "commit"): - raise InvalidManifest( - f"unknown {self.section_name}.{key} key in" - f"manifest: {code_key}" - ) - else: - raise InvalidManifest( - f"unknown {self.section_name} key in manifest: {key}" - ) - - if self.required_keys is not None: - for key in self.required_keys: - if isinstance(self.expected_keys, list) and key not in values: - raise InvalidManifest( - f"required {self.section_name} key not found in manifest: {key}" - ) - - -class DataObjectValidator(ListValuesValidator): - """ - Validator for the manifest data objects (inputs or outputs). - """ - - def __init__(self, version: int, section_name: str) -> None: - if version == 0: - expected_keys = ("uuid", "path", "imas") - elif version > 0: - expected_keys = ("uri",) - else: - raise KeyError("Invalid version.") - super().__init__(version, section_name, expected_keys) - - def validate(self, values: Union[list, dict]) -> None: - super().validate(values) - if values is None: - return - seen_uris = set() - for value in values: - if self.version > 0: - uri = URI(value["uri"]) - if uri.scheme not in ("file", "imas"): - raise InvalidManifest(f"unknown uri scheme: {uri.scheme}") - if str(uri) in seen_uris: - raise InvalidManifest( - f"Duplicate URI found in {self.section_name}: {uri}" - ) - seen_uris.add(str(uri)) - - -class InputsValidator(DataObjectValidator): - """ - Validator for the manifest inputs list. - """ + elif v.scheme == "simdb": + _ = UUID(v.path) - def __init__(self, version): - super().__init__(version, "inputs") - - -class OutputsValidator(DataObjectValidator): - """ - Validator for the manifest outputs list. - """ - - def __init__(self, version): - super().__init__(version, "outputs") - - -class VersionValidator(ManifestValidator): - """ - Validator for manifest version. - """ - - def __init__(self, version: int): - super().__init__(version) - - def validate(self, values: Union[List, Dict]) -> None: - if not isinstance(values, int): - raise InvalidManifest("version must be an integer") - - -class AliasValidator(ManifestValidator): - """ - Validator for simulation alias. - """ - - def __init__(self, version: int): - super().__init__(version) - - def validate(self, values: Union[List, Dict]) -> None: - if not isinstance(values, str): - raise InvalidManifest("alias must be a string") - if urllib.parse.quote(values) != values: - raise InvalidAlias(f"illegal characters in alias: {values}") + return v + @property + def type(self): + return self._type -class DescriptionValidator(ManifestValidator): - """ - Validator for simulation description. - """ +class Source(DataObject): pass -class ResponsibleValidator(ManifestValidator): - """ - Validator for simulation Responsible. - """ - +class Sink(DataObject): pass -def ndarray_constructor( - loader: yaml.SafeLoader, node: yaml.nodes.MappingNode -) -> np.ndarray: - mapping = loader.construct_mapping(node, deep=True) - return np.array(mapping["data"], mapping.get("dtype", None)) +class Manifest(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + manifest_version: Literal[2] = Field(default=2) + alias: Optional[str] = None + responsible_name: Optional[str] = None + inputs_raw: List[Source] = Field(default_factory=list, alias="inputs") + outputs_raw: List[Sink] = Field(default_factory=list, alias="outputs") + metadata_raw: List[Dict[str, Any]] = Field(default_factory=list, alias="metadata") -def get_loader() -> Type[yaml.SafeLoader]: - loader = yaml.SafeLoader - loader.add_constructor("!ndarray", ndarray_constructor) - return loader + _path: Path = PrivateAttr(default_factory=Path) + _inputs: List[Source] = PrivateAttr(default_factory=list) + _outputs: List[Sink] = PrivateAttr(default_factory=list) + _metadata: Dict[str, Any] = PrivateAttr(default_factory=dict) + @model_validator(mode="before") + @classmethod + def check_deprecated_version(cls, data: Any) -> Any: + if isinstance(data, dict) and "version" in data: + warnings.warn( + "The 'version' field is deprecated and will be removed " + "in a future version. Please use 'manifest_version' instead.", + DeprecationWarning, + stacklevel=3, + ) + if "manifest_version" not in data: + data["manifest_version"] = data.pop("version") + return data -class MetaDataValidator(ListValuesValidator): - """ - Validator for the manifest Metadata list. - """ + @field_validator("alias") + @classmethod + def validate_alias(cls, v: Optional[str]) -> Optional[str]: + if v is not None and urllib.parse.quote(v) != v: + raise ValueError(f"illegal characters in alias: {v}") + return v - forbidden_characters = (":", "=", "#") + @field_validator("metadata_raw") + @classmethod + def validate_metadata(cls, v: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + forbidden_characters = {":", "=", "#"} + for item in v: + if len(item) != 1: + raise ValueError("metadata values should be a name value pair") + name = next(iter(item)) + bad_chars = set(name).intersection(forbidden_characters) + if bad_chars: + raise ValueError( + f"invalid metadata field name {name} - " + f"contains forbidden character(s): {', '.join(bad_chars)}" + ) + return v - def __init__(self, version: int) -> None: - section_name = "metadata" - required_keys = ("machine", "code", "description") - super().__init__(version, section_name, required_keys) + @field_validator("inputs_raw", "outputs_raw") + @classmethod + def validate_uris(cls, v: List[DataObject], info) -> List[DataObject]: + seen_uris = set() + for item in v: + uri_str = item.name + if uri_str in seen_uris: + raise ValueError( + "Duplicate URI found in " + f"{info.field_name.replace('_raw', '')}: {uri_str}" + ) + seen_uris.add(uri_str) + return v - def validate(self, values: Union[list, dict]) -> None: - super().validate(values) + @model_validator(mode="after") + def resolve_metadata(self, info) -> "Manifest": + for metadata_item in self.metadata_raw: + self._metadata.update(metadata_item) + return self - for item in values: - name = next(iter(item)) - for char in MetaDataValidator.forbidden_characters: - if char in name: - raise InvalidManifest( - f"invalid metadata field name {name}- contains forbidden " - f"character {char}" - ) + def _resolve_manifest_items(self, items, factory_cls, skip_glob_check): + resolved = [] + for item in items: + if item.type == DataType.FILE and item.uri.path: + path_obj = Path(item.uri.path) + matches = list(path_obj.parent.glob(path_obj.name)) -class WorkflowValidator(DictValuesValidator): - """ - Validator for the manifest workflow dictionary. - """ - - def __init__(self, version: int) -> None: - section_name = "workflow" - if version == 0: - expected_keys = ("name", "git", "repo", "commit", "codes") - required_keys = ("name", "commit", "codes") - elif version == 1: - expected_keys = ( - "name", - "developer", - "date", - "repo", - "commit", - "codes", - "branch", - ) - required_keys = ("name", "repo", "commit", "branch") - else: - raise KeyError("Invalid version.") - super().__init__(version, section_name, expected_keys, required_keys) + if not matches and skip_glob_check: + matches = [path_obj] + if not matches: + raise ValueError(f"No files found matching path {path_obj}") -def _update_dict(old: Dict, new: Dict) -> None: - for k, v in new.items(): - if k in old: - if isinstance(old[k], list): - old[k].append(v) + for p in matches: + resolved.append( + factory_cls( + uri=SimDBUrl.build(scheme="file", path=p.as_posix()) + ) + ) else: - old[k] = [old[k], v] - else: - old[k] = v + resolved.append(item) + return resolved + @model_validator(mode="after") + def resolve_inputs_and_outputs(self, info) -> "Manifest": + context = info.context or {} + skip_glob_check = context.get("skip_glob_check", False) -class Manifest: - """ - Class to handle reading, writing & validation of simulation manifest files. - """ + context.setdefault( + "base_path", + self._path.absolute().parent if self._path != Path() else Path.cwd(), + ) - def __init__(self) -> None: - self._data: Union[Dict, List, None] = None - self._path: Path = Path() - self._metadata: Dict = {} + self._inputs = self._resolve_manifest_items( + self.inputs_raw, Source, skip_glob_check + ) + self._outputs = self._resolve_manifest_items( + self.outputs_raw, Sink, skip_glob_check + ) - @property - def metadata(self) -> Dict: - return self._metadata + return self @classmethod - def from_template(cls) -> "Manifest": - """ - Create an empty manifest from a template file. - - :return: A new manifest object. - """ - manifest = cls() - dir_path = Path(__file__).resolve().parent - manifest.load(dir_path / "template.yaml") - return manifest - - @property - def inputs(self) -> Iterable[Source]: - sources = [] - base_path = self._path.absolute().parent - if ( - isinstance(self._data, dict) - and "inputs" in self._data - and self._data["inputs"] - ): - for i in self._data["inputs"]: - source = Source(base_path, i["uri"]) - if source.type == DataObject.Type.FILE: - if source.uri and source.uri.path: - source_path = Path(source.uri.path) - names = [ - str(p) for p in source_path.parent.glob(source_path.name) - ] - if not names: - raise InvalidManifest( - f"No files found matching path {source.uri.path}" - ) - for name in names: - sources.append(Source(base_path, "file://" + name)) - else: - sources.append(source) - return sources - - @property - def outputs(self) -> Iterable[Sink]: - sinks = [] - base_path = self._path.absolute().parent - if isinstance(self._data, dict) and self._data["outputs"]: - for i in self._data["outputs"]: - sink = Sink(base_path, i["uri"]) - if sink.type == DataObject.Type.FILE: - if sink.uri and sink.uri.path: - sink_path = Path(sink.uri.path) - names = [str(p) for p in sink_path.parent.glob(sink_path.name)] - for name in names: - sinks.append(Sink(base_path, "file://" + name)) - else: - sinks.append(sink) - return sinks - - @property - def alias(self) -> Optional[str]: - if isinstance(self._data, dict): - return self._data.get("alias", None) - return None + def _get_loader(cls): + def ndarray_constructor( + loader: yaml.SafeLoader, node: yaml.nodes.MappingNode + ) -> np.ndarray: + mapping = loader.construct_mapping(node, deep=True) + return np.array(mapping["data"], mapping.get("dtype", None)) - @property - def responsible_name(self) -> Optional[str]: - if isinstance(self._data, dict): - return self._data.get("responsible_name", None) - return None + loader = yaml.SafeLoader + loader.add_constructor("!ndarray", ndarray_constructor) + return loader - @property - def version(self) -> int: - if isinstance(self._data, dict): - return self._data.get("version", 2) - return 0 + @classmethod + def from_template(cls) -> "Manifest": + dir_path = Path(__file__).resolve().parent + with (dir_path / "template.yaml").open() as file: + try: + raw_data = yaml.load(file, Loader=cls._get_loader()) + except yaml.YAMLError as err: + raise ValueError("badly formatted manifest") from err - @property - def manifest_version(self) -> int: - if isinstance(self._data, dict): - return self._data.get("manifest_version", 2) - return 0 - - def _load_metadata(self, root_path: Path, path: Path): - try: - if not path.is_absolute(): - root_dir = root_path.absolute().parent - path = root_dir / path - with path.open() as metadata_file: - _update_dict( - self._metadata, yaml.load(metadata_file, Loader=get_loader()) - ) - except yaml.YAMLError as err: - raise InvalidManifest(f"failed to read metadata file {path}") from err - - def _convert_version(self): - if isinstance(self._data, dict) and self.version == 0: - self._convert_metadata() - self._data["inputs"] = self._convert_files(self._data["inputs"]) - self._data["outputs"] = self._convert_files(self._data["outputs"]) - self._data["version"] = 1 - - def _convert_metadata(self) -> None: - if isinstance(self._data, dict): - for item in ("description", "workflow"): - if item in self._data: - self._metadata[item] = self._data[item] - del self._data[item] - - for key, value in self._metadata.items(): - if key == "workflow": - if "git" in value: - value["repo"] = value["git"] - del value["git"] - if "codes" in value: - codes = value["codes"] - new_codes = [] - for code in codes: - for _, v in code.items(): - new_codes.append(v) - value["codes"] = new_codes + model = cls.model_validate(raw_data, context={"skip_glob_check": True}) + model._path = dir_path / "template.yaml" + return model @classmethod - def _convert_files(cls, files: List[Dict[str, str]]) -> List[Dict[str, "URI"]]: - scheme_map = { - "uuid": "simdb", - "path": "file", - "imas": "imas", - } - - new_files = [] - for file in files: - for k, v in file.items(): - new_files.append({"uri": URI(scheme=scheme_map[k], path=v)}) - return new_files - - def load(self, file_path: Path) -> None: - """ - Load a manifest from the given file. - - :param file_path: Path to the file read. - :return: None - """ - - self._path: Path = file_path + def load_from_file( + cls, file_path: Path, overrides: Optional[dict] = None + ) -> "Manifest": with file_path.open() as file: try: - self._data = yaml.load(file, Loader=get_loader()) + raw_data = yaml.load(file, Loader=cls._get_loader()) except yaml.YAMLError as err: - raise InvalidManifest("badly formatted manifest") from err + raise ValueError("badly formatted manifest") from err - if isinstance(self._data, dict) and "metadata" in self._data: - self._data["metadata"] or [] - self._metadata["metadata"] = self._data["metadata"] + if overrides: + raw_data.update(overrides) + + model = cls.model_validate( + raw_data, context={"base_path": file_path.absolute().parent} + ) + model._path = file_path + return model def save(self, out_file: TextIO) -> None: - """ - Save the manifest to the given file. - - :param out_file: The output text stream to write the manifest to. - :return: None - """ - - yaml.dump(self._data, out_file, default_flow_style=False) - - def validate(self) -> None: - """ - Validate the manifest object. - - :return: None - """ - if self._data is None: - raise InvalidManifest("failed to read manifest") - if isinstance(self._data, list): - raise InvalidManifest( - "badly formatted manifest - top level sections must be keys not a list" - ) + yaml.dump( + self.model_dump(mode="json", by_alias=True, exclude_none=True), + out_file, + default_flow_style=False, + ) - if "manifest_version" not in self._data: - print("warning: no version given in manifest, assuming version 2.") - - version = self.version - - if version == 2: - section_validators = { - "manifest_version": VersionValidator(version), - "alias": AliasValidator(version), - "inputs": InputsValidator(version), - "outputs": OutputsValidator(version), - "metadata": MetaDataValidator(version), - "responsible_name": ResponsibleValidator(version), - } - else: - raise InvalidManifest(f"Unknown manifest version {version}.") - - for section in self._data: - if section not in section_validators: - raise InvalidManifest(f"Unknown manifest section found {section}.") - - required_sections = ("manifest_version", "outputs", "inputs") - for section in required_sections: - if section not in self._data: - raise InvalidManifest( - f"Required manifest section '{section}' not found." - ) + @property + def version(self) -> int: + return self.manifest_version - for name, values in self._data.items(): - section_validators[name].validate(values) - self._convert_version() + @property + def metadata(self) -> Dict[str, Any]: + return self._metadata + + @property + def inputs(self) -> Iterable[Source]: + return self._inputs + + @property + def outputs(self) -> Iterable[Sink]: + return self._outputs diff --git a/src/simdb/cli/remote_api.py b/src/simdb/cli/remote_api.py index a81d082c..b8312060 100644 --- a/src/simdb/cli/remote_api.py +++ b/src/simdb/cli/remote_api.py @@ -34,12 +34,11 @@ from simdb.config import Config from simdb.database.models import Simulation -from simdb.imas.utils import imas_files +from simdb.imas.utils import SimDBUrl, imas_files from simdb.json import CustomDecoder, CustomEncoder from simdb.remote import APIConstants -from simdb.uri import URI -from .manifest import DataObject +from .manifest import DataType if TYPE_CHECKING: from simdb.database.models import File, Simulation, Watcher @@ -145,9 +144,9 @@ def check_return(res: "requests.Response") -> None: def _get_paths(file: "File") -> Iterable[Path]: - if file.type == DataObject.Type.FILE: + if file.type == DataType.FILE: if file.uri and file.uri.path: - return [file.uri.path] + return [Path(file.uri.path)] return [] else: return imas_files(file.uri) @@ -676,7 +675,7 @@ def _push_file( sim_data: Dict[str, Any], chunk_size: int, out_stream: IO, - type: DataObject.Type, + type: DataType, ): msg = f"Uploading file {path} " print(msg, file=out_stream, end="") @@ -690,12 +689,12 @@ def _push_file( if num_chunks == 0: # empty file self._send_chunk(0, b"", chunk_size, uuid, file_type, sim_data) - if type == DataObject.Type.FILE: + if type == DataType.FILE: self.post( "files", data={ "simulation": sim_data, - "obj_type": DataObject.Type.FILE, + "obj_type": DataType.FILE, "files": [ { "chunks": num_chunks, @@ -787,7 +786,7 @@ def push_simulation( copy_ids = options.get("copy_ids", True) for file in simulation.inputs: - if file.type == DataObject.Type.IMAS: + if file.type == DataType.IMAS: if not copy_ids: print(f"Skipping IDS data {file}", file=out_stream, flush=True) continue @@ -833,7 +832,7 @@ def push_simulation( else: if file.uri and file.uri.path: self._push_file( - file.uri.path, + Path(file.uri.path), file.uuid, "input", sim_data, @@ -843,7 +842,7 @@ def push_simulation( ) for file in simulation.outputs: - if file.type == DataObject.Type.IMAS: + if file.type == DataType.IMAS: if not copy_ids: print(f"Skipping IDS data {file}", file=out_stream, flush=True) continue @@ -895,7 +894,7 @@ def push_simulation( else: if file.uri and file.uri.path: self._push_file( - file.uri.path, + Path(file.uri.path), file.uuid, "output", sim_data, @@ -1000,24 +999,28 @@ def pull_simulation( for file in itertools.chain(simulation.inputs, simulation.outputs): info = self._get_file_info(file.uuid) - if file.type == DataObject.Type.FILE: + if file.type == DataType.FILE: (path, checksum) = info[0] rel_path = directory / path.relative_to(common_root) self._pull_file(file.uuid, 0, checksum, path, rel_path, out_stream) - file.uri = URI(file.uri, path=rel_path.absolute()) - elif file.type == DataObject.Type.IMAS: + file.uri = SimDBUrl.build( + scheme="file", path=rel_path.absolute().as_posix() + ) + elif file.type == DataType.IMAS: for index, (path, checksum) in enumerate(info): rel_path = directory / path.relative_to(common_root) self._pull_file( file.uuid, index, checksum, path, rel_path, out_stream ) + qs = dict(file.uri.query_params()) to_path = ( - directory - / Path(file.uri.query.get("path")).relative_to(common_root) + directory / Path(qs.get("path", "")).relative_to(common_root) ).absolute() - backend = file.uri.query.get("backend") - file.uri = URI(f"imas:{backend}?path={to_path}") + backend = qs.get("backend") + file.uri = SimDBUrl.build( + scheme="imas", path=backend, query=f"path={to_path}" + ) return simulation diff --git a/src/simdb/cli/template.yaml b/src/simdb/cli/template.yaml index 08bba616..887f4f00 100644 --- a/src/simdb/cli/template.yaml +++ b/src/simdb/cli/template.yaml @@ -7,7 +7,6 @@ alias: simulation-alias inputs: - uri: file:///home/user/path/to/a/file1 - uri: imas:hdf5?path=/path/to/folder - - uri: imas://host:port/uda?path=/path/to/folder/on/host&backend=hdf5 # Data and log files. outputs: - uri: file:///home/user/path/to/a/file2 diff --git a/src/simdb/database/models/file.py b/src/simdb/database/models/file.py index 6cb8a561..5cd8aab1 100644 --- a/src/simdb/database/models/file.py +++ b/src/simdb/database/models/file.py @@ -7,13 +7,12 @@ from sqlalchemy import Column from sqlalchemy import types as sql_types -from simdb import uri as urilib from simdb.checksum import sha1_checksum -from simdb.cli.manifest import DataObject +from simdb.cli.manifest import DataType from simdb.config.config import Config from simdb.docstrings import inherit_docstrings from simdb.imas.checksum import checksum as imas_checksum -from simdb.imas.utils import imas_files, imas_timestamp +from simdb.imas.utils import SimDBUrl, imas_files, imas_timestamp from simdb.remote.models import FileData, FileGetDataResponse, FileInfo from .base import Base @@ -30,15 +29,15 @@ class File(Base): __tablename__ = "files" id = Column(sql_types.Integer, primary_key=True) uuid = Column(UUID, nullable=False, unique=True, index=True) - uri: urilib.URI = Column(URI(1024), nullable=True) + uri: SimDBUrl = Column(URI(1024), nullable=True) checksum = Column(sql_types.String(64), nullable=True) - type = Column(sql_types.Enum(DataObject.Type), nullable=True) + type = Column(sql_types.Enum(DataType), nullable=True) datetime = Column(sql_types.DateTime, nullable=False) def __init__( self, - type: DataObject.Type, - uri: urilib.URI, + type: DataType, + uri: SimDBUrl, ids_list: Optional[list] = None, perform_integrity_check: bool = True, config: Optional[Config] = None, @@ -49,7 +48,7 @@ def __init__( if perform_integrity_check: self.datetime = self.get_creation_date() - if type == DataObject.Type.IMAS and ids_list is None: + if type == DataType.IMAS and ids_list is None: raise ValueError("IDS list is not set") self.checksum = self.generate_checksum(config, ids_list or []) @@ -76,18 +75,18 @@ def __repr__(self): def generate_checksum(self, config, ids_list: list): if config and config.get_option("development.disable_checksum", default=False): return "" - elif self.type == DataObject.Type.IMAS: + elif self.type == DataType.IMAS: checksum = imas_checksum(self.uri, ids_list) - elif self.type == DataObject.Type.FILE: + elif self.type == DataType.FILE: checksum = sha1_checksum(self.uri) else: raise NotImplementedError(f"Cannot generate checksum for type {self.type}.") return checksum def get_creation_date(self) -> datetime_: - if self.type == DataObject.Type.IMAS: + if self.type == DataType.IMAS: return imas_timestamp(self.uri) - elif self.type == DataObject.Type.FILE: + elif self.type == DataType.FILE: if self.uri.path is None: raise ValueError("Data object uri path not set") return datetime_.fromtimestamp(Path(self.uri.path).stat().st_ctime) @@ -98,9 +97,7 @@ def get_creation_date(self) -> datetime_: def from_data(cls, data: Dict) -> "File": data_type = checked_get(data, "type", str) uri = checked_get(data, "uri", str) - file = File( - DataObject.Type[data_type], urilib.URI(uri), perform_integrity_check=False - ) + file = File(DataType[data_type], SimDBUrl(uri), perform_integrity_check=False) file.uuid = checked_get(data, "uuid", uuid.UUID) file.checksum = checked_get(data, "checksum", str) file.datetime = date_parser.parse(checked_get(data, "datetime", str)) @@ -110,9 +107,7 @@ def from_data(cls, data: Dict) -> "File": def from_data_model(cls, data: FileData) -> "File": data_type = data.type uri = data.uri - file = File( - DataObject.Type[data_type], urilib.URI(uri), perform_integrity_check=False - ) + file = File(DataType[data_type], SimDBUrl(uri), perform_integrity_check=False) file.uuid = data.uuid file.checksum = data.checksum file.datetime = data.datetime @@ -141,7 +136,7 @@ def to_model_with_path(self) -> FileGetDataResponse: if self.type.name == "FILE": if self.uri.path is None: raise ValueError("File path not set") - files = [FileInfo(path=self.uri.path, checksum=self.checksum)] + files = [FileInfo(path=Path(self.uri.path), checksum=self.checksum)] else: files = [ FileInfo(path=path, checksum=sha1_checksum(URI(f"file:{path}"))) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 201e9bfc..a508c137 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union +from simdb.imas.utils import SimDBUrl from simdb.remote.models import ( FileDataList, MetadataData, @@ -32,9 +33,8 @@ ClauseElement.__bool__ = lambda self: True # type: ignore -import re -from simdb.cli.manifest import DataObject, Manifest +from simdb.cli.manifest import DataObject, DataType, Manifest from simdb.config.config import Config from simdb.docstrings import inherit_docstrings from simdb.imas.metadata import load_metadata @@ -42,10 +42,10 @@ check_time, extract_ids_occurrence, get_path_for_legacy_uri, + is_legacy_imas_uri, list_idss, open_imas, ) -from simdb.uri import URI from .base import Base from .file import File @@ -80,11 +80,10 @@ def _update_legacy_uri(data_object: DataObject): - if data_object.uri is None: - raise ValueError("Data object uri is not set") path = get_path_for_legacy_uri(data_object.uri) - backend = data_object.uri.query.get("backend", default="hdf5") - return URI(f"imas:{backend}?path={path}") + qs = dict(data_object.uri.query_params()) + backend = qs.get("backend", "hdf5") + return SimDBUrl.build(scheme="imas", path=backend, query=f"path={path}") class MetaDataWrapper: @@ -187,9 +186,7 @@ def __init__( all_input_idss = [] for input in manifest.inputs: - if input.uri is None: - raise ValueError("Source uri is not set") - if input.type == DataObject.Type.IMAS: + if input.type == DataType.IMAS: entry = open_imas(input.uri) idss = list_idss(entry) @@ -202,7 +199,7 @@ def __init__( entry.close() file = File(input.type, input.uri, all_input_idss, config=config) - if input.type == DataObject.Type.IMAS and "path" not in input.uri.query: + if input.type == DataType.IMAS and is_legacy_imas_uri(input.uri): file.uri = _update_legacy_uri(input) self.inputs.append(file) @@ -212,9 +209,7 @@ def __init__( all_output_idss = [] for output in manifest.outputs: - if output.uri is None: - raise ValueError("Sink uri is not set") - if output.type == DataObject.Type.IMAS: + if output.type == DataType.IMAS: entry = open_imas(output.uri) idss = list_idss(entry) for ids in idss: @@ -225,14 +220,13 @@ def __init__( meta = load_metadata(entry) entry.close() - flattened_meta: Dict[str, str] = {} - flatten_dict(flattened_meta, meta) + flattened_meta = flatten_dict(meta) for key, value in flattened_meta.items(): self.set_meta(key, value) file = File(output.type, output.uri, all_output_idss, config=config) - if output.type == DataObject.Type.IMAS and "path" not in output.uri.query: + if output.type == DataType.IMAS and is_legacy_imas_uri(output.uri): file.uri = _update_legacy_uri(output) self.outputs.append(file) @@ -240,12 +234,9 @@ def __init__( if all_output_idss: self.set_meta("ids", "[{}]".format(", ".join(all_output_idss))) - flattened_dict: Dict[str, str] = {} - flatten_dict(flattened_dict, manifest.metadata) + flattened_dict = flatten_dict(manifest.metadata) for key, value in flattened_dict.items(): - if "metadata#" in key: - key = re.sub(r"^metadata#\d+\.?", "", key) self.set_meta(key, value) if not self.find_meta("status"): self.set_meta("status", Simulation.Status.NOT_VALIDATED.value) @@ -325,18 +316,19 @@ def validate_meta(self) -> None: def file_paths(self) -> Set[Path]: def _get_path(file: File) -> Optional[Path]: if file.uri.scheme == "file": - if file.type == DataObject.Type.FILE: - return file.uri.path - elif file.type == DataObject.Type.IMAS: + if file.type == DataType.FILE: if file.uri.path is None: raise ValueError("Data object path is not set") - return file.uri.path.parent + return Path(file.uri.path) + elif file.type == DataType.IMAS: + if file.uri.path is None: + raise ValueError("Data object path is not set") + return Path(file.uri.path).parent else: raise ValueError(f"Unknown file type {file.type}") elif file.uri.scheme == "imas": - return ( - Path(file.uri.query["path"]) if "path" in file.uri.query else None - ) + qs = dict(file.uri.query_params()) + return Path(qs["path"]) if "path" in qs else None return None file_paths = set() diff --git a/src/simdb/database/models/types.py b/src/simdb/database/models/types.py index 24c8f479..1b18bb20 100644 --- a/src/simdb/database/models/types.py +++ b/src/simdb/database/models/types.py @@ -5,7 +5,7 @@ from sqlalchemy import types as sql_types from sqlalchemy.dialects import postgresql -from simdb import uri as urilib +from simdb.imas.utils import SimDBUrl class UUID(sql_types.TypeDecorator): @@ -64,21 +64,19 @@ class URI(sql_types.TypeDecorator): @property def python_type(self): - return urilib.URI + return SimDBUrl - def process_bind_param(self, value: Optional[urilib.URI], dialect) -> Optional[str]: + def process_bind_param(self, value: Optional[SimDBUrl], dialect) -> Optional[str]: if value is None: return value return str(value) - def process_result_value( - self, value: Optional[str], dialect - ) -> Optional[urilib.URI]: + def process_result_value(self, value: Optional[str], dialect) -> Optional[SimDBUrl]: if value is None: return value - return urilib.URI(value) + return SimDBUrl(value) - def process_literal_param(self, value, dialect) -> Optional[urilib.URI]: + def process_literal_param(self, value, dialect) -> Optional[SimDBUrl]: return self.process_result_value(value, dialect) diff --git a/src/simdb/database/models/utils.py b/src/simdb/database/models/utils.py index 517286c6..6be42f8b 100644 --- a/src/simdb/database/models/utils.py +++ b/src/simdb/database/models/utils.py @@ -1,22 +1,48 @@ from collections import deque -from typing import Any, Deque, Dict, List, Tuple, Type, Union +from typing import Any, Deque, Dict, List, Tuple, Type, Union, cast FLATTEN_DICT_DELIM = "." def flatten_dict( - out_dict: Dict[str, Any], - in_dict: Dict[str, Union[Dict, List, Any]], - prefix: Tuple = (), -): - for key, value in in_dict.items(): + data: Dict[str, Any], + prefix: str = "", + delim: str = ".", +) -> Dict[str, Any]: + """ + Recursively flattens a nested dictionary, representing nested structures with keys + delimited by a string, and lists with index suffixes. + + Example: + >>> flatten_dict({"a": {"b": 1}, "c": [2, {"d": 3}]}) + {'a.b': 1, 'c#1': 2, 'c#2.d': 3} + + :param data: The nested dictionary to flatten. + :param prefix: The prefix to prepend to keys. + :param delim: The delimiter string to separate keys. + :return: A flat dictionary. + """ + result: Dict[str, Any] = {} + + for key, value in data.items(): + full_key = f"{prefix}{delim}{key}" if prefix else key + if isinstance(value, dict): - flatten_dict(out_dict, value, (*prefix, key)) + result.update(flatten_dict(value, full_key, delim)) elif isinstance(value, list): for i, el in enumerate(value): - flatten_dict(out_dict, el, (*prefix, f"{key}#{i + 1}")) + if isinstance(el, dict): + result.update( + flatten_dict( + cast(Dict[str, Any], el), f"{full_key}#{i + 1}", delim + ) + ) + else: + result[f"{full_key}#{i + 1}"] = el else: - out_dict[FLATTEN_DICT_DELIM.join((*prefix, key))] = value + result[full_key] = value + + return result def _parse_index(head: str) -> Tuple[bool, str, int]: diff --git a/src/simdb/imas/checksum.py b/src/simdb/imas/checksum.py index 2451c8bb..d9d403ef 100644 --- a/src/simdb/imas/checksum.py +++ b/src/simdb/imas/checksum.py @@ -1,17 +1,14 @@ import hashlib from pathlib import Path -from simdb.uri import URI +from simdb.imas.utils import SimDBUrl from .utils import imas_files, list_idss, open_imas IGNORED_FIELDS = ("data_dictionary", "access_layer", "access_layer_language") -def checksum(uri: URI, ids_list: list) -> str: - if uri.scheme != "imas": - raise ValueError(f"invalid scheme for imas checksum: {uri.scheme}") - +def checksum(uri: SimDBUrl, ids_list: list) -> str: sha1 = hashlib.sha1() if not ids_list: diff --git a/src/simdb/imas/metadata.py b/src/simdb/imas/metadata.py index b6cd38a7..6f077167 100644 --- a/src/simdb/imas/metadata.py +++ b/src/simdb/imas/metadata.py @@ -6,6 +6,8 @@ import imas.dd_zip import imas.ids_defs +from simdb.remote.models import _array_to_range + class MetricException(Exception): pass @@ -129,7 +131,7 @@ def load_imas_metadata(ids_dist, entry) -> dict: ids = imas.convert_ids(ids, latest_dd_version) for node in imas.util.tree_iter(ids): metadata[extract_ids_path(str(node.coordinates)).replace("/", ".")] = ( # type: ignore - node.value # type: ignore + _array_to_range(node.value) # type: ignore ) return metadata diff --git a/src/simdb/imas/utils.py b/src/simdb/imas/utils.py index 0cc80c43..7b9ff234 100644 --- a/src/simdb/imas/utils.py +++ b/src/simdb/imas/utils.py @@ -1,7 +1,7 @@ import os from datetime import datetime from pathlib import Path -from typing import Any, List +from typing import Any, List, Optional import imas import imas.exception @@ -9,9 +9,9 @@ import semantic_version from dateutil import parser from imas import DBEntry +from pydantic import AnyUrl, TypeAdapter from simdb.config import Config -from simdb.uri import URI class ImasError(Exception): @@ -22,6 +22,36 @@ class ImasError(Exception): INT_MISSING_VALUE = -999999999 +class SimDBUrl(AnyUrl): + @classmethod + def build( + cls, + *, + scheme: str, + host: Optional[str] = None, + port: Optional[int] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + **kwargs, + ) -> "SimDBUrl": + url_str = f"{scheme}:" + + if host: + url_str += f"//{host}" + if port: + url_str += f":{port}" + url_str += "/" + + url_str += path or "" + if query: + url_str += f"?{query}" + if fragment: + url_str += f"#{fragment}" + + return TypeAdapter(cls).validate_python(url_str) + + def is_missing(value: Any): """ Returns whether the given value is one of IMASs 'missing' values. @@ -104,7 +134,7 @@ def check_time(entry: DBEntry, ids: str, occurrence) -> None: def _is_al5() -> bool: - al_env = os.environ.get("AL_VERSION", default=None) + al_env = os.environ.get("AL_VERSION") ual_env = os.environ.get("UAL_VERSION", default="5.0.0") version = ( semantic_version.Version(al_env) @@ -114,8 +144,9 @@ def _is_al5() -> bool: return version >= semantic_version.Version("5.0.0") -def _open_legacy(uri: URI) -> DBEntry: - path = uri.query.get("path", default=None) +def _open_legacy(uri: SimDBUrl) -> DBEntry: + qs = dict(uri.query_params()) + path = qs.get("path") if path is not None: raise ImasError(f"cannot open AL5 URI {uri} with AL4") @@ -123,12 +154,12 @@ def _open_legacy(uri: URI) -> DBEntry: "hdf5": imas.ids_defs.HDF5_BACKEND, } - backend = uri.query.get("backend", default=None) - user = uri.query.get("user", default=None) - database = uri.query.get("database", default=None) - version = uri.query.get("version", default="3") - shot = uri.query.get("shot", default=None) - run = uri.query.get("run", default=None) + backend = qs.get("backend") + user = qs.get("user") + database = qs.get("database") + version = qs.get("version", "3") + shot = qs.get("shot") + run = qs.get("run") if backend not in backend_ids: raise ImasError( @@ -174,7 +205,7 @@ def _open_legacy(uri: URI) -> DBEntry: return entry -def open_imas(uri: URI) -> DBEntry: +def open_imas(uri: SimDBUrl) -> DBEntry: """ Open an IMAS URI and return the IMAS entry object. @@ -182,30 +213,32 @@ def open_imas(uri: URI) -> DBEntry: @return: the IMAS data entry object """ - if uri.scheme != "imas": - raise ValueError(f"invalid imas URI: {uri} - invalid scheme") - - if uri.query is None: - raise ValueError(f"invalid imas URI: {uri} - no query found in URI") - if not _is_al5(): return _open_legacy(uri) - path = uri.query.get("path", default=None) - if path is None: - path = get_path_for_legacy_uri(uri) - backend = uri.query.get("backend", default="mdsplus") - uri = URI(f"imas:{backend}?path={path}") + if uri.path is None: + raise ValueError(f"invalid imas URI: {uri} - no path found in URI") + + if uri.scheme == "file": + imas_uri = uri.path + elif uri.scheme == "imas": + qs = dict(uri.query_params()) + path = qs.get("path") + if path is None: + raise ValueError(f"invalid imas URI: {uri} - no path found") + imas_uri = str(uri) + else: + raise ValueError(f"invalid imas URI: {uri} - invalid scheme") try: - entry = imas.DBEntry(str(uri), "r") + entry = imas.DBEntry(imas_uri, "r") except Exception as err: raise ImasError(f"failed to open IMAS data with URI {uri}") from err return entry -def imas_timestamp(uri: URI) -> datetime: +def imas_timestamp(uri: SimDBUrl) -> datetime: """ Extract the timestamp from the IDS data for the given IMAS URI. @@ -226,18 +259,23 @@ def imas_timestamp(uri: URI) -> datetime: return timestamp -def get_path_for_legacy_uri(uri: URI) -> Path: - user = uri.query.get("user", default=None) - database = uri.query.get("database", default=None) - version = uri.query.get("version", default="3") - shot = uri.query.get("shot", default=None) - run = uri.query.get("run", default=None) - backend = uri.query.get("backend", default="hdf5") +def is_legacy_imas_uri(uri: SimDBUrl) -> bool: + return bool(uri.scheme == "imas" and dict(uri.query_params()).get("path") is None) + + +def get_path_for_legacy_uri(uri: SimDBUrl) -> Path: + qs = dict(uri.query_params()) + user = qs.get("user") + database = qs.get("database") + version = qs.get("version", "3") + shot = qs.get("shot") + run = qs.get("run") + backend = qs.get("backend", "hdf5") if database is None or shot is None or run is None or version is None: raise ValueError(f"Invalid legacy URI {uri}") if user == "public": - imas_home = os.environ.get("IMAS_HOME", default=None) + imas_home = os.environ.get("IMAS_HOME") if imas_home is None: raise ValueError( "Legacy URI passed with user=public but $IMAS_HOME is not set" @@ -255,14 +293,15 @@ def get_path_for_legacy_uri(uri: URI) -> Path: return path / shot / run -def _get_path(uri: URI) -> Path: +def _get_path(uri: SimDBUrl) -> Path: """ Return the path to the data for a given IMAS URI @param uri: a valid IMAS URI @return: the path of the IDS data for the given IMAS URI """ - path = uri.query.get("path", default=None) + qs = dict(uri.query_params()) + path = qs.get("path") if path is None: raise ValueError("Invalid IMAS URI - path not found in query arguments") @@ -272,7 +311,7 @@ def _get_path(uri: URI) -> Path: return path -def imas_files(uri: URI) -> List[Path]: +def imas_files(uri: SimDBUrl) -> List[Path]: """ Return all the files associated with the given IMAS URI. @@ -280,6 +319,12 @@ def imas_files(uri: URI) -> List[Path]: @return: a list of files which contains the IDS data for the backend specified in the URI """ + if uri.path is None: + raise ValueError("URI path should not be none") + + if uri.scheme == "file": + return [Path(uri.path).absolute()] + backend = str(uri.path) if backend.startswith("/"): backend = backend[1:] @@ -300,7 +345,7 @@ def imas_files(uri: URI) -> List[Path]: raise ValueError(f"Unknown IMAS backend {backend}") -def convert_uri(uri: URI, path: Path, config: Config) -> URI: +def convert_uri(uri: SimDBUrl, path: Path, config: Config) -> SimDBUrl: """ Converts a local IMAS URI to a remote access IMAS URI based on the server.imas_remote_host configuration option. @@ -312,19 +357,21 @@ def convert_uri(uri: URI, path: Path, config: Config) -> URI: @param config: Config to read the server.imas_remote_host and server.imas_remote_port options from """ - host = config.get_option("server.imas_remote_host", default=None) + host = config.get_string_option("server.imas_remote_host", default=None) if host is None: raise ValueError( "Cannot process IMAS data as server.imas_remote_host configuration option " "not set" ) - port = config.get_option("server.imas_remote_port", default=None) + port = config.get_string_option("server.imas_remote_port", default=None) backend = uri.path - if port is None: - return URI(f"imas://{host}/uda?path={path}&backend={backend}") - else: - port = int(port) - return URI(f"imas://{host}:{port}/uda?path={path}&backend={backend}") + return SimDBUrl.build( + scheme="imas", + host=host, + port=None if port is None else int(port), + path="uda", + query=f"path={path}&backend={backend}", + ) def extract_ids_occurrence(ids: str) -> tuple[str, int]: diff --git a/src/simdb/remote/apis/files.py b/src/simdb/remote/apis/files.py index e846fa57..868abf6e 100644 --- a/src/simdb/remote/apis/files.py +++ b/src/simdb/remote/apis/files.py @@ -10,10 +10,10 @@ from werkzeug.datastructures import FileStorage from simdb.checksum import sha1_checksum -from simdb.cli.manifest import DataObject +from simdb.cli.manifest import DataType from simdb.database import DatabaseError, models from simdb.imas.checksum import checksum as imas_checksum -from simdb.imas.utils import imas_files +from simdb.imas.utils import SimDBUrl, imas_files from simdb.json import CustomDecoder from simdb.remote.core.auth import User, requires_auth from simdb.remote.core.errors import error @@ -21,7 +21,6 @@ from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app from simdb.remote.models import FileDataList, FileGetDataResponse -from simdb.uri import URI api = Namespace("files", path="/") @@ -40,29 +39,32 @@ def _verify_file( Path(current_app.simdb_config.get_string_option("server.upload_folder")) / sim_uuid.hex ) - if sim_file.type == DataObject.Type.FILE: + if sim_file.type == DataType.FILE: if sim_file.uri.path is None: raise ValueError("File does not have an associated path") - path = secure_path(sim_file.uri.path, common_root, staging_dir) + path = secure_path(Path(sim_file.uri.path), common_root, staging_dir) if not path.exists(): raise ValueError(f"file {path} does not exist") - checksum = sha1_checksum(URI(scheme="file", path=path)) + checksum = sha1_checksum(SimDBUrl.build(scheme="file", path=path.as_posix())) if sim_file.checksum != checksum: raise ValueError(f"checksum failed for file {sim_file!r}") - elif sim_file.type == DataObject.Type.IMAS: + elif sim_file.type == DataType.IMAS: uri = sim_file.uri - path_value = uri.query.get("path") + qs = dict(uri.query_params()) + path_value = qs.get("path") if path_value is None: raise ValueError("The 'path' key is missing in the URI query") if common_root == Path("/"): - uri.query.set("path", str(staging_dir) + path_value) + path_value = str(staging_dir) + path_value elif common_root is not None and common_root == path_value: - uri.query.set( - "path", path_value.replace(str(common_root), str(staging_dir)) - ) + path_value = path_value.replace(str(common_root), str(staging_dir)) + else: - uri.query.set("path", str(staging_dir)) - checksum = imas_checksum(uri, ids_list or []) + path_value = str(staging_dir) + new_uri = uri.build( + scheme=uri.scheme, path=uri.path, query=f"path={path_value}" + ) + checksum = imas_checksum(new_uri, ids_list or []) if sim_file.checksum != checksum: raise ValueError(f"checksum failed for simulation {sim_file.uri}") @@ -104,7 +106,7 @@ def _stage_file_from_chunks( found_files.append((file, sim_file)) for file, sim_file in found_files: - path = secure_path(sim_file.uri.path, common_root, staging_dir) + path = secure_path(Path(sim_file.uri.path), common_root, staging_dir) path.parent.mkdir(parents=True, exist_ok=True) file_chunk_info = chunk_info.get( sim_file.uuid.hex, {"chunk_size": 0, "chunk": 0, "num_chunks": 1} @@ -126,13 +128,13 @@ def _process_simulation_data(data: dict) -> Response: simulation = models.Simulation.from_data(data["simulation"]) sim_file_paths = simulation.file_paths() common_root = find_common_root(sim_file_paths) - if DataObject.Type(data["obj_type"]) == DataObject.Type.FILE: + if DataType(data["obj_type"]) == DataType.FILE: for file in data["files"]: sim_file = _check_file_is_in_simulation( simulation, uuid.UUID(file["file_uuid"]), file["file_type"] ) _verify_file(simulation.uuid, sim_file, common_root) - elif DataObject.Type(data["obj_type"]) == DataObject.Type.IMAS: + elif DataType(data["obj_type"]) == DataType.IMAS: file = data["files"][0] sim_files = ( simulation.inputs if file["file_type"] == "input" else simulation.outputs @@ -204,7 +206,7 @@ class NonIMASFileDownload(Resource): def get(self, file_uuid: str, user: Optional[User] = None): try: file: models.File = current_app.db.get_file(file_uuid) - if file.type != DataObject.Type.FILE: + if file.type != DataType.FILE: return error("Invalid file type for download") if file.uri.path is None: return error("File path is not set") @@ -224,7 +226,7 @@ class FileDownload(Resource): def get(self, file_uuid: str, file_index: int, user: Optional[User] = None): try: file: models.File = current_app.db.get_file(file_uuid) - if file.type == DataObject.Type.FILE: + if file.type == DataType.FILE: if file_index != 0: return error(f"invalid file_index for file {file.uri}") if file.uri.path is None: diff --git a/src/simdb/remote/apis/v1/simulations.py b/src/simdb/remote/apis/v1/simulations.py index 8a7c299f..3a56c9cb 100644 --- a/src/simdb/remote/apis/v1/simulations.py +++ b/src/simdb/remote/apis/v1/simulations.py @@ -12,6 +12,7 @@ from simdb.database import DatabaseError from simdb.database.models import simulation as models_sim from simdb.email.server import EmailServer +from simdb.imas.utils import SimDBUrl from simdb.query import QueryType, parse_query_arg from simdb.remote import APIConstants from simdb.remote.core.alias import create_alias_dir @@ -20,7 +21,6 @@ from simdb.remote.core.errors import error from simdb.remote.core.path import secure_path from simdb.remote.core.typing import current_app -from simdb.uri import URI from simdb.validation import ValidationError, Validator api = Namespace("simulations", path="/") @@ -206,7 +206,7 @@ def post(self, user: User): if not path.exists(): raise ValueError(f"simulation file {sim_file.uuid} not uploaded") if sim_file.uri.scheme.name == "file": - sim_file.uri = URI(scheme="file", path=path) + sim_file.uri = SimDBUrl.build(scheme="file", path=path) result = { "ingested": simulation.uuid.hex, diff --git a/src/simdb/remote/apis/v1_1/simulations.py b/src/simdb/remote/apis/v1_1/simulations.py index 539d889b..03c956b9 100644 --- a/src/simdb/remote/apis/v1_1/simulations.py +++ b/src/simdb/remote/apis/v1_1/simulations.py @@ -12,6 +12,7 @@ from simdb.database import DatabaseError from simdb.database.models import simulation as models_sim from simdb.email.server import EmailServer +from simdb.imas.utils import SimDBUrl from simdb.query import QueryType, parse_query_arg from simdb.remote import APIConstants from simdb.remote.core.alias import create_alias_dir @@ -20,7 +21,6 @@ from simdb.remote.core.errors import error from simdb.remote.core.path import secure_path from simdb.remote.core.typing import current_app -from simdb.uri import URI from simdb.validation import ValidationError, Validator api = Namespace("simulations", path="/") @@ -235,7 +235,7 @@ def post(self, user: User): if not path.exists(): raise ValueError(f"simulation file {sim_file.uuid} not uploaded") if sim_file.uri.scheme.name == "file": - sim_file.uri = URI(scheme="file", path=path) + sim_file.uri = SimDBUrl.build(scheme="file", path=path) result = { "ingested": simulation.uuid.hex, diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index d9d38ab7..169248af 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -14,7 +14,7 @@ from simdb.database.models import simulation as models_sim from simdb.database.models import watcher as models_watcher from simdb.email.server import EmailServer -from simdb.imas.utils import convert_uri +from simdb.imas.utils import SimDBUrl, convert_uri from simdb.query import QueryType, parse_query_arg from simdb.remote.core.alias import create_alias_dir from simdb.remote.core.auth import User, requires_auth @@ -46,7 +46,6 @@ StatusPatchData, ValidationResult, ) -from simdb.uri import URI from simdb.validation import ValidationError, Validator from simdb.validation.file import find_file_validator @@ -266,22 +265,25 @@ def post( and sim_file.uri.scheme == "file" and sim_file.uri.path is not None ): - path = secure_path(sim_file.uri.path, common_root, staging_dir) + path = secure_path( + Path(sim_file.uri.path), common_root, staging_dir + ) if not path.exists(): raise ResponseException( f"simulation file {sim_file.uuid} not uploaded" ) - sim_file.uri = URI(scheme="file", path=path) + sim_file.uri = SimDBUrl.build(scheme="file", path=path.as_posix()) elif sim_file.uri.scheme == "imas": + qs = dict(sim_file.uri.query_params()) if copy_files: path = secure_path( - Path(sim_file.uri.query["path"]), + Path(qs["path"]), common_root, staging_dir, is_file=common_root is not None, ) else: - path = Path(sim_file.uri.query["path"]) + path = Path(qs["path"]) sim_file.uri = convert_uri(sim_file.uri, path, config) result = SimulationPostResponse( diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 0a8218ce..14d18806 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -33,7 +33,7 @@ RootModel as _RootModel, ) -from simdb.cli.manifest import DataObject +from simdb.cli.manifest import DataType HexUUID = Annotated[UUID, PlainSerializer(lambda x: x.hex, return_type=str)] """UUID serialized as a hex string.""" @@ -493,7 +493,7 @@ class FileRegistrationData(BaseModel): simulation: SimulationData """The simulation the files belong to.""" - obj_type: DataObject.Type + obj_type: DataType """The type of the data object being registered.""" files: List[FileRegistrationItem] """List of file registration items.""" diff --git a/src/simdb/uri.py b/src/simdb/uri.py deleted file mode 100644 index 0cacb824..00000000 --- a/src/simdb/uri.py +++ /dev/null @@ -1,165 +0,0 @@ -from pathlib import Path -from typing import Dict, Optional, Union - -from urllib3.util.url import LocationParseError, Url, parse_url - - -class URIParserError(ValueError): - def __init__(self, msg: str): - super().__init__(msg) - - -class Query: - """ - Class representing the URI query parameters. - """ - - _args: Dict[str, Optional[str]] - - def __init__(self, query: Optional[str]): - query = "" if query is None else query - self._args = {} - for arg in query.split("&"): - key, *value = arg.split("=") - if key and value: - self._args[key] = "=".join(value) - elif key: - self._args[key] = None - - @classmethod - def empty(cls): - return cls(None) - - def __str__(self): - return "&".join(f"{k}={v}" for k, v in self._args.items()) - - def __bool__(self): - return len(self._args) > 0 - - def __contains__(self, item) -> bool: - return item in self._args - - def __getitem__(self, name): - return self._args[name] - - def get(self, name: str, *, default: Optional[str] = None) -> Optional[str]: - return self._args.get(name, default) - - def set(self, name: str, value: str) -> None: - self._args[name] = value - - def remove(self, name: str) -> None: - del self._args[name] - - -class Authority: - """ - Class representing URI authority. - """ - - __slots__ = ("auth", "host", "port") - - def __init__(self, host: Optional[str], port: Optional[int], auth: Optional[str]): - self.host: Optional[str] = host - self.port: Optional[int] = port - self.auth: Optional[str] = auth - - @classmethod - def empty(cls): - return cls(None, None, None) - - def __bool__(self): - return bool(self.host) or bool(self.port) or bool(self.auth) - - def __str__(self): - string = "" - if self.host: - string = f"{self.host}" - if self.auth: - string = f"{self.auth}@{string}" - if self.port is not None: - string = f"{string}:{self.port}" - return string - - def __repr__(self): - return f"Authority({self.host}, {self.port}, {self.auth})" - - -class URI: - """ - Class for parsing and representing a URI. - """ - - __slots__ = ("authority", "fragment", "path", "query", "scheme") - - def __init__(self, uri: Union[str, "URI", None] = None, *, scheme=None, path=None): - """ - Create a URI object by either parsing a URI string or copying from an existing - URI object. - - :param uri: A URI string, another URI to copy from or None for an empty URI. - :param scheme: The URI scheme. Takes precedence over any scheme found from the - URI argument. - :param path: The URI path. Takes precedence over any path found from the URI - argument. - """ - self.scheme: Optional[str] = None - self.query: Query = Query.empty() - self.path: Optional[Path] = None - self.authority: Authority = Authority.empty() - self.fragment: Optional[str] = None - - if uri is not None: - try: - result: Url = parse_url(str(uri)) - except LocationParseError as err: - raise URIParserError("failed to parse URI") from err - self.scheme = result.scheme - self.query = Query(result.query) - self.authority = Authority(result.host, result.port, result.auth) - if result.path is not None: - if ( - self.scheme == "imas" - and not self.authority - and result.path.startswith("/") - ): - self.path = Path(result.path[1:]) - else: - self.path = Path(result.path) - self.fragment = result.fragment - if scheme is not None: - self.scheme = scheme - if path is not None: - self.path = Path(path) - if not self.scheme: - raise URIParserError("failed to parse URI: no scheme specified") - - @property - def uri(self) -> str: - """ - Return the URI object as a URI string. - - :return: A string representation of the URI. - """ - uri = f"{self.scheme}:" - if self.authority: - path = "" - if self.path and str(self.path) != ".": - path = self.path if self.path.is_absolute() else "/" / self.path - uri += f"//{self.authority}{path}" - elif self.path and str(self.path) != ".": - uri += f"{self.path}" - if self.query: - uri += f"?{self.query}" - if self.fragment: - uri += f"#{self.fragment}" - return uri - - def __repr__(self): - return f"URI({self.uri})" - - def __str__(self): - return self.uri - - def __eq__(self, other): - return self.uri == other.uri diff --git a/src/simdb/validation/file/ids_validator.py b/src/simdb/validation/file/ids_validator.py index 5f668e08..07452a41 100644 --- a/src/simdb/validation/file/ids_validator.py +++ b/src/simdb/validation/file/ids_validator.py @@ -1,5 +1,7 @@ from pathlib import Path +from simdb.imas.utils import SimDBUrl + try: from imas_validator.report.validationReportGenerator import ( ValidationReportGenerator, @@ -11,7 +13,6 @@ except ImportError: imas_validator_available = False -from simdb.uri import URI from simdb.validation.validator import ValidationError from .validator_base import FileValidatorBase @@ -72,7 +73,7 @@ def options(self) -> dict: "rule_files": [], } - def validate_uri(self, uri: URI, validate_options): + def validate_uri(self, uri: SimDBUrl, validate_options): if not imas_validator_available: raise RuntimeError( "IMAS-validator not available, please install this optional dependency" @@ -82,12 +83,16 @@ def validate_uri(self, uri: URI, validate_options): return try: - backend = uri.query.get("backend") - path = uri.query.get("path") - validate_uri = f"imas:{backend}?path={path}" + qs = dict(uri.query_params()) + backend = qs.get("backend") + path = qs.get("path") + validate_uri = SimDBUrl.build( + scheme="imas", path=backend, query=f"path={path}" + ) validate_output = validate( - imas_uri=URI(validate_uri).uri, validate_options=validate_options + imas_uri=validate_uri.encoded_string(), + validate_options=validate_options, ) validate_result = all(result.success for result in validate_output.results) diff --git a/src/simdb/validation/file/validator_base.py b/src/simdb/validation/file/validator_base.py index 5e7f222e..c2fa2a64 100644 --- a/src/simdb/validation/file/validator_base.py +++ b/src/simdb/validation/file/validator_base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from simdb.uri import URI +from simdb.imas.utils import SimDBUrl class FileValidatorBase(ABC): @@ -26,7 +26,7 @@ def options(self) -> dict: """ @abstractmethod - def validate_uri(self, uri: URI, validate_options): + def validate_uri(self, uri: SimDBUrl, validate_options): """ Validate the given simulation output file. """ diff --git a/tests/cli/test_cli_manifest_command.py b/tests/cli/test_cli_manifest_command.py index ba9b0758..efc96740 100644 --- a/tests/cli/test_cli_manifest_command.py +++ b/tests/cli/test_cli_manifest_command.py @@ -16,11 +16,10 @@ def test_manifest_check_command(manifest): ) assert result.exception is None assert "ok" in result.output - assert manifest.return_value.load.called - (args, kwargs) = manifest.return_value.load.call_args + assert manifest.load_from_file.called + (args, kwargs) = manifest.load_from_file.call_args assert str(args[0]) == str(file_name) assert kwargs == {} - assert manifest.return_value.validate.called def test_manifest_check_command_integration(): diff --git a/tests/cli/utils.py b/tests/cli/utils.py index 51e96af6..7c99c36f 100644 --- a/tests/cli/utils.py +++ b/tests/cli/utils.py @@ -21,14 +21,14 @@ def create_manifest() -> Path: # Data and configuration files inputs: # - uri: simdb://simdb.iter.org/123e4567-e89b-12d3-a456-426655440000 - - uri: file:///home/user/path/to/a/file1 - - uri: imas:///user?shot=10000&run=0 + - uri: file:///$MANIFEST_DIR/utils.py + - uri: imas:///user?shot=10000&run=0&database=west # - uri: imas+uda:///TOKAMAK?shot=10000&run=0&server=uda.server.org:56565 # Data and log files. outputs: - - uri: file:///home/user/path/to/a/file2 - - uri: imas:///user?shot=10000&run=1 + - uri: file:///$MANIFEST_DIR/utils.py + - uri: imas:///user?shot=10000&run=1&database=west metadata: - values: diff --git a/tests/database/test_models.py b/tests/database/test_models.py index 3c311970..ffc5be5b 100644 --- a/tests/database/test_models.py +++ b/tests/database/test_models.py @@ -1,9 +1,9 @@ from pathlib import Path from unittest import mock -from simdb.cli.manifest import DataObject +from simdb.cli.manifest import DataType from simdb.database.models import Simulation -from simdb.uri import URI +from simdb.imas.utils import SimDBUrl def test_create_simulation_without_manifest_creates_empty_sim(): @@ -23,18 +23,18 @@ def test_create_simulation_with_manifest(manifest_cls, data_object_cls): path = Path(__file__).absolute() manifest = manifest_cls() data_object = data_object_cls() - data_object.type = DataObject.Type.FILE - data_object.uri = URI(f"file://{path}") + data_object.type = DataType.FILE + data_object.uri = SimDBUrl(f"file://{path}") manifest.inputs = [data_object] manifest.outputs = [data_object] manifest.metadata = {"description": "test description", "uploaded_by": "test user"} sim = Simulation(manifest=manifest) assert len(sim.inputs) == 1 - assert sim.inputs[0].type == DataObject.Type.FILE - assert sim.inputs[0].uri == URI(f"file://{path}") + assert sim.inputs[0].type == DataType.FILE + assert sim.inputs[0].uri == SimDBUrl(f"file://{path}") assert len(sim.outputs) == 1 - assert sim.outputs[0].type == DataObject.Type.FILE - assert sim.outputs[0].uri == URI(f"file://{path}") + assert sim.outputs[0].type == DataType.FILE + assert sim.outputs[0].uri == SimDBUrl(f"file://{path}") assert len(sim.meta) == 3 meta = {m.element: m.value for m in sim.meta} assert meta == { diff --git a/tests/remote/api/conftest.py b/tests/remote/api/conftest.py index 6fbe2e3e..b8f9ff84 100644 --- a/tests/remote/api/conftest.py +++ b/tests/remote/api/conftest.py @@ -35,7 +35,7 @@ @pytest.fixture(scope="session") def client(): if not has_flask: - pytest.skip("Flask not installed") # type: ignore + pytest.skip("Flask not installed") config = Config() config.load() db_fd, db_file = tempfile.mkstemp() @@ -68,7 +68,7 @@ def client(): @pytest.fixture(scope="session") def client_copy_files(): if not has_flask: - pytest.skip("Flask not installed") # type: ignore + pytest.skip("Flask not installed") config = Config() config.load() db_fd, db_file = tempfile.mkstemp() diff --git a/tests/remote/api/test_files.py b/tests/remote/api/test_files.py index f37107a4..e4e6e328 100644 --- a/tests/remote/api/test_files.py +++ b/tests/remote/api/test_files.py @@ -13,7 +13,7 @@ post_simulation, ) -from simdb.cli.manifest import DataObject +from simdb.cli.manifest import DataType from simdb.json import CustomEncoder from simdb.remote.models import ( ChunkInfo, @@ -97,7 +97,7 @@ def create_simulation_with_file( registration_data = FileRegistrationData( simulation=simulation_data.simulation, - obj_type=DataObject.Type.FILE, + obj_type=DataType.FILE, files=[ FileRegistrationItem( chunks=num_chunks, diff --git a/tests/remote/test_authentication.py b/tests/remote/test_authentication.py index 48c0309f..d62f56a3 100644 --- a/tests/remote/test_authentication.py +++ b/tests/remote/test_authentication.py @@ -24,7 +24,7 @@ def test_check_role(get_string_option): app = Flask("test") config = Config() app.simdb_config = config # type: ignore - with app.app_context(): # type: ignore + with app.app_context(): get_string_option.return_value = 'user1,"user2", user3' ok = check_role(config, User("user1", ""), "test_role") assert ok diff --git a/tests/test_manifest.py b/tests/test_manifest.py new file mode 100644 index 00000000..d090e374 --- /dev/null +++ b/tests/test_manifest.py @@ -0,0 +1,172 @@ +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from simdb.cli.manifest import DataType, Manifest + + +def test_valid_manifest_loading_and_validation(tmp_path): + # Setup some dummy input/output files to satisfy the file globbing check + input_file = tmp_path / "input.json" + input_file.write_text("{}") + output_file = tmp_path / "output.json" + output_file.write_text("{}") + + manifest_yaml = f"""\ +manifest_version: 2 +alias: test-simulation-alias +responsible_name: "John Doe" +inputs: + - uri: file://{input_file.as_posix()} + - uri: imas:///user?shot=10000&run=0&database=west +outputs: + - uri: file://{output_file.as_posix()} + - uri: imas:///user?shot=10000&run=1&database=west +metadata: + - machine: ITER + - code: + name: METIS + - description: sample description +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + manifest = Manifest.load_from_file(manifest_file) + + assert manifest.manifest_version == 2 + assert manifest.version == 2 + assert manifest.alias == "test-simulation-alias" + assert manifest.responsible_name == "John Doe" + + inputs = list(manifest.inputs) + assert len(inputs) == 2 + assert inputs[0].type == DataType.FILE + assert inputs[1].type == DataType.IMAS + + outputs = list(manifest.outputs) + assert len(outputs) == 2 + assert outputs[0].type == DataType.FILE + assert outputs[1].type == DataType.IMAS + + +def test_manifest_path_expansion_with_manifest_dir(tmp_path): + # Setup some dummy file + input_file = tmp_path / "test_input.json" + input_file.write_text("{}") + + manifest_yaml = """\ +manifest_version: 2 +inputs: + - uri: file:///$MANIFEST_DIR/test_input.json +outputs: + - uri: imas:///user?shot=10000&run=1&database=west +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + manifest = Manifest.load_from_file(manifest_file) + + inputs = list(manifest.inputs) + assert len(inputs) == 1 + assert inputs[0].uri.path is not None + assert Path(inputs[0].uri.path) == input_file + + +def test_invalid_manifest_version(tmp_path): + # version must be 2 + manifest_yaml = """\ +manifest_version: 1 +inputs: [] +outputs: [] +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="Input should be 2"): + Manifest.load_from_file(manifest_file) + + +def test_manifest_version_must_be_integer(tmp_path): + manifest_yaml = """\ +manifest_version: "2" +inputs: [] +outputs: [] +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="Input should be 2"): + Manifest.load_from_file(manifest_file) + + +def test_unknown_section(tmp_path): + manifest_yaml = """\ +manifest_version: 2 +inputs: [] +outputs: [] +unknown_field: true +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + Manifest.load_from_file(manifest_file) + + +def test_invalid_alias_characters(tmp_path): + manifest_yaml = """\ +manifest_version: 2 +alias: "invalid alias" +inputs: [] +outputs: [] +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="illegal characters in alias"): + Manifest.load_from_file(manifest_file) + + +def test_duplicate_uris_in_inputs(tmp_path): + manifest_yaml = """\ +manifest_version: 2 +inputs: + - uri: imas:///user?shot=10000&run=0&database=west + - uri: imas:///user?shot=10000&run=0&database=west +outputs: [] +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="Duplicate URI found in inputs"): + Manifest.load_from_file(manifest_file) + + +def test_invalid_metadata_forbidden_characters(tmp_path): + manifest_yaml = """\ +manifest_version: 2 +inputs: [] +outputs: [] +metadata: + - "machine:name": "value" +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="contains forbidden character"): + Manifest.load_from_file(manifest_file) + + +def test_missing_files_causes_validation_error(tmp_path): + manifest_yaml = """\ +manifest_version: 2 +inputs: + - uri: file:///nonexistent_file_path_xyz.json +outputs: [] +""" + manifest_file = tmp_path / "manifest.yaml" + manifest_file.write_text(manifest_yaml) + + with pytest.raises(ValidationError, match="No files found matching path"): + Manifest.load_from_file(manifest_file) diff --git a/tests/test_uri.py b/tests/test_uri.py deleted file mode 100644 index 1f6d8a2c..00000000 --- a/tests/test_uri.py +++ /dev/null @@ -1,63 +0,0 @@ -import unittest -from pathlib import Path - -from simdb.uri import URI - - -class URITests(unittest.TestCase): - def test_empty_uri(self): - uri = URI("imas:") - self.assertEqual(uri.scheme, "imas") - - def test_uri_without_scheme_throws(self): - with self.assertRaises(ValueError): - _uri = URI() - - def test_uri_with_path(self): - uri = URI("imas:hdf5") - self.assertEqual(uri.path, Path("hdf5")) - - def test_uri_with_query(self): - uri = URI("imas:hdf5?path=foo") - self.assertIn("path", uri.query) - self.assertEqual(uri.query["path"], "foo") - - def test_uri_with_authority(self): - uri = URI("imas://uda.iter.org/hdf5?path=foo") - self.assertEqual(uri.authority.host, "uda.iter.org") - - def test_uri_with_authority_with_port(self): - uri = URI("imas://uda.iter.org:56565/hdf5?path=foo") - self.assertEqual(uri.authority.port, 56565) - - def test_uri_with_authority_with_auth(self): - uri = URI("imas://user:passwd@uda.iter.org/hdf5?path=foo") - self.assertEqual(uri.authority.auth, "user:passwd") - - def test_get_query_argument_with_default(self): - uri = URI("imas:hdf5") - self.assertNotIn("path", uri.query) - self.assertEqual(uri.query.get("path", default="foo"), "foo") - - def test_updating_uri_query(self): - uri = URI("imas:hdf5?path=foo") - uri.query.set("path", "bar") - self.assertIn("path", uri.query) - self.assertEqual(uri.query["path"], "bar") - - def test_removing_argument_from_uri_query(self): - uri = URI("imas:hdf5?path=foo") - uri.query.remove("path") - self.assertNotIn("path", uri.query) - - def test_uri_to_string_just_path(self): - uri = URI("imas:hdf5") - self.assertEqual(str(uri), "imas:hdf5") - - def test_uri_to_string_full_uri(self): - uri = URI("imas://authority/hdf5?path=foo#frag") - self.assertEqual(str(uri), "imas://authority/hdf5?path=foo#frag") - - def test_uri_with_empty_authority_to_string(self): - uri = URI("imas:///hdf5?path=foo#frag") - self.assertEqual(str(uri), "imas:hdf5?path=foo#frag")