Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ def media_storage(self, value: MediaStorage) -> None:
def artifact_path_exists(self) -> bool:
    """Return whether ``self.artifact_path`` currently exists."""
    artifact_root = self.artifact_path
    return artifact_root.exists()

@cached_property
def resolved_artifact_path(self) -> Path:
    """Return ``self.artifact_path`` after ``Path.resolve()``.

    Wrapped in ``cached_property`` so the resolution is computed once per
    instance and reused by later lookups.
    """
    artifact_root = self.artifact_path
    return artifact_root.resolve()

@cached_property
def resolved_dataset_name(self) -> str:
dataset_path = self.artifact_path / self.dataset_name
dataset_path = self._resolve_artifact_subpath(self.dataset_name)
if dataset_path.exists() and len(list(dataset_path.iterdir())) > 0:
if self.resume in (ResumeMode.ALWAYS, ResumeMode.IF_POSSIBLE):
return self.dataset_name
Expand All @@ -93,27 +97,27 @@ def resolved_dataset_name(self) -> str:

@property
def base_dataset_path(self) -> Path:
    """Return the dataset directory, resolved inside the artifact root.

    The path is built through ``_resolve_artifact_subpath`` rather than a
    plain ``/`` join so that a dataset name which would escape the artifact
    directory is rejected instead of silently returned.

    Returns:
        The resolved path for ``self.resolved_dataset_name``.

    Raises:
        ArtifactStorageError: If the resolved path lies outside the
            artifact path.
    """
    return self._resolve_artifact_subpath(self.resolved_dataset_name)

@property
def dropped_columns_dataset_path(self) -> Path:
    """Return the dropped-columns directory inside the dataset directory.

    Built through ``_resolve_artifact_subpath`` so the result is checked
    against the artifact root instead of being a plain ``/`` join.

    Returns:
        The resolved path ``<dataset>/<dropped_columns_folder_name>``.

    Raises:
        ArtifactStorageError: If the resolved path lies outside the
            artifact path.
    """
    return self._resolve_artifact_subpath(self.resolved_dataset_name, self.dropped_columns_folder_name)

@property
def final_dataset_path(self) -> Path:
    """Return the final-dataset directory inside the dataset directory.

    Built through ``_resolve_artifact_subpath`` so the result is checked
    against the artifact root instead of being a plain ``/`` join.

    Returns:
        The resolved path ``<dataset>/<final_dataset_folder_name>``.

    Raises:
        ArtifactStorageError: If the resolved path lies outside the
            artifact path.
    """
    return self._resolve_artifact_subpath(self.resolved_dataset_name, self.final_dataset_folder_name)

@property
def metadata_file_path(self) -> Path:
    """Return the metadata file path inside the dataset directory.

    Built through ``_resolve_artifact_subpath`` so the result is checked
    against the artifact root instead of being a plain ``/`` join.

    Returns:
        The resolved path ``<dataset>/METADATA_FILENAME``.

    Raises:
        ArtifactStorageError: If the resolved path lies outside the
            artifact path.
    """
    return self._resolve_artifact_subpath(self.resolved_dataset_name, METADATA_FILENAME)

@property
def partial_results_path(self) -> Path:
    """Return the partial-results directory inside the dataset directory.

    Built through ``_resolve_artifact_subpath`` so the result is checked
    against the artifact root instead of being a plain ``/`` join.

    Returns:
        The resolved path ``<dataset>/<partial_results_folder_name>``.

    Raises:
        ArtifactStorageError: If the resolved path lies outside the
            artifact path.
    """
    return self._resolve_artifact_subpath(self.resolved_dataset_name, self.partial_results_folder_name)

@property
def processors_outputs_path(self) -> Path:
    """Return the processors-outputs directory inside the dataset directory.

    Built through ``_resolve_artifact_subpath`` so the result is checked
    against the artifact root instead of being a plain ``/`` join.

    Returns:
        The resolved path ``<dataset>/<processors_outputs_folder_name>``.

    Raises:
        ArtifactStorageError: If the resolved path lies outside the
            artifact path.
    """
    return self._resolve_artifact_subpath(self.resolved_dataset_name, self.processors_outputs_folder_name)

@field_validator("artifact_path")
def validate_artifact_path(cls, v: Path | str) -> Path:
Expand Down Expand Up @@ -143,6 +147,8 @@ def validate_folder_names(self):
for name in folder_names:
if any(char in invalid_chars for char in name):
raise ArtifactStorageError(f"🛑 Directory name '{name}' contains invalid characters.")
if name in {".", ".."}:
raise ArtifactStorageError(f"🛑 Directory name '{name}' must not be '.' or '..'.")

# Initialize media storage with DISK mode by default
self._media_storage = MediaStorage(
Expand Down Expand Up @@ -364,5 +370,16 @@ def update_metadata(self, updates: dict) -> Path:
existing_metadata.update(updates)
return self.write_metadata(existing_metadata)

def _resolve_artifact_subpath(self, *parts: str) -> Path:
    """Join ``parts`` onto the resolved artifact path and verify containment.

    Args:
        *parts: Path components appended, in order, to
            ``self.resolved_artifact_path``.

    Returns:
        The joined path after ``Path.resolve()``.

    Raises:
        ArtifactStorageError: If the resolved path is not relative to
            ``self.resolved_artifact_path``.
    """
    artifact_root = self.resolved_artifact_path
    resolved = artifact_root.joinpath(*parts).resolve()
    try:
        # relative_to() raises ValueError when `resolved` is not under the root.
        resolved.relative_to(artifact_root)
    except ValueError as exc:
        joined_parts = str(Path(*parts)) if parts else "."
        message = f"🛑 Directory name '{joined_parts}' resolves outside the artifact path."
        raise ArtifactStorageError(message) from exc
    return resolved

def _get_stage_path(self, stage: BatchStage) -> Path:
return getattr(self, resolve_string_enum(stage, BatchStage).value)
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,30 @@ def test_artifact_storage_invalid_characters_in_folder_names(tmp_path, invalid_c
{"final_dataset_folder_name": f"invalid{invalid_char}name"},
{"partial_results_folder_name": f"invalid{invalid_char}name"},
{"dropped_columns_folder_name": f"invalid{invalid_char}name"},
{"processors_outputs_folder_name": f"invalid{invalid_char}name"},
]

for params in invalid_params:
with pytest.raises(ArtifactStorageError, match="contains invalid characters"):
ArtifactStorage(artifact_path=tmp_path, **params)


@pytest.mark.parametrize("reserved_name", [".", ".."])
@pytest.mark.parametrize(
    "field_name",
    [
        "dataset_name",
        "final_dataset_folder_name",
        "partial_results_folder_name",
        "dropped_columns_folder_name",
        "processors_outputs_folder_name",
    ],
)
def test_artifact_storage_rejects_reserved_directory_names(tmp_path, field_name, reserved_name):
    """Each parametrized name field must reject '.' and '..' at construction."""
    overrides = {field_name: reserved_name}
    expected_message = r"must not be '\.' or '\.\.'"

    with pytest.raises(ArtifactStorageError, match=expected_message):
        ArtifactStorage(artifact_path=tmp_path, **overrides)


def test_artifact_storage_read_parquet_files(stub_artifact_storage):
df1 = lazy.pd.DataFrame([{"id": 1, "data": {"some_list": ["yes"]}}, {"id": 2, "data": {"some_list": ["no"]}}])
df2 = lazy.pd.DataFrame({"id": 3, "data": {"some_list": []}})
Expand Down Expand Up @@ -197,6 +214,14 @@ def test_artifact_storage_path_validation(stub_artifact_storage):
assert stub_artifact_storage.dropped_columns_dataset_path.is_absolute()


def test_base_dataset_path_rejects_paths_outside_artifact_root(tmp_path):
    """``base_dataset_path`` must raise when the dataset name escapes the root."""
    storage = ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset")
    # Plant the value directly in the instance dict so the property under test
    # sees ".." without the name going through construction.
    vars(storage)["resolved_dataset_name"] = ".."
    expected_message = "resolves outside the artifact path"

    with pytest.raises(ArtifactStorageError, match=expected_message):
        _ = storage.base_dataset_path


def test_artifact_storage_file_operations(stub_artifact_storage):
df = lazy.pd.DataFrame({"test": [1, 2, 3]})

Expand Down
15 changes: 15 additions & 0 deletions packages/data-designer/tests/interface/test_data_designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FileContentsSeedSource,
HuggingFaceSeedSource,
)
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
from data_designer.engine.resources.seed_reader import (
FileSystemSeedReader,
Expand Down Expand Up @@ -736,6 +737,20 @@ def test_create_raises_error_when_builder_fails(
assert isinstance(exc_info.value.__cause__, RuntimeError)


def test_create_rejects_reserved_dataset_name(
    stub_artifact_path, stub_model_providers, stub_sampler_only_config_builder, stub_managed_assets_path
):
    """``DataDesigner.create`` must raise when ``dataset_name`` is '..'."""
    designer_kwargs = {
        "artifact_path": stub_artifact_path,
        "model_providers": stub_model_providers,
        "secret_resolver": PlaintextResolver(),
        "managed_assets_path": stub_managed_assets_path,
    }
    data_designer = DataDesigner(**designer_kwargs)
    expected_message = r"must not be '\.' or '\.\.'"

    with pytest.raises(ArtifactStorageError, match=expected_message):
        data_designer.create(stub_sampler_only_config_builder, num_records=1, dataset_name="..")


def test_create_raises_error_when_profiler_fails(
stub_artifact_path, stub_model_providers, stub_sampler_only_config_builder, stub_managed_assets_path
):
Expand Down
Loading