diff --git a/CLAUDE.md b/CLAUDE.md
index e5d53bd0..13ffd64a 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -109,11 +109,17 @@ podman rm $(podman ps -a --filter name=forge- -q)
 | `forge:spec-pending` | Awaiting spec approval |
 | `forge:plan-pending` | Awaiting plan approval |
 | `forge:task-pending` | Awaiting task approval |
+| `forge:task-takeover` | Standalone task/epic takeover trigger |
+| `forge:task-triage-pending` | Task takeover awaiting triage completion |
+| `forge:task-plan-pending` | Task takeover awaiting plan approval |
+| `forge:task-plan-approved` | Task takeover plan approved |
+| `forge:managed:task` | Task identity preservation label |
+| `forge:managed:task-takeover` | Task takeover identity preservation label |
 | `forge:blocked` | Workflow blocked, needs intervention |
 | `forge:retry` | Trigger retry of failed step |
 | `forge:yolo` | Autonomous mode — skip all artifact approval gates (see warning below) |
 
-> **⚠️ Warning — `forge:yolo`:** This label removes all human checkpoints for PRD, spec, plan, and task approval. Forge will proceed autonomously from ticket creation to implementation without pausing for review. Only use this on tickets where you are confident in the requirements and comfortable with Forge making all planning decisions. It does not bypass code review (the human review gate on the implementation PR is always required).
+> **⚠️ Warning — `forge:yolo`:** This label removes all human checkpoints for PRD, spec, plan, task, and task plan approval. Forge will proceed autonomously from ticket creation to implementation without pausing for review. Only use this on tickets where you are confident in the requirements and comfortable with Forge making all planning decisions. It does not bypass code review (the human review gate on the implementation PR is always required).
 
 ## Jira Comment Syntax
 
diff --git a/docs/guide/labels.md b/docs/guide/labels.md
index 16d7461c..e1bda18c 100644
--- a/docs/guide/labels.md
+++ b/docs/guide/labels.md
@@ -24,17 +24,29 @@ These labels advance the pipeline. Forge watches for label changes via Jira webh
 | Plan Approval Gate | `forge:plan-pending` | Forge | Plan posted; waiting for approval |
 | Plan Approval Gate | `forge:plan-approved` | Human | Approve plan and trigger task decomposition + implementation |
 
+### Task Takeover Workflow
+
+Standalone Tasks and Epics can be processed using Task Takeover trigger labels. These tickets bypass the standard parent Feature validation.
+
+| Stage | Pending Label | Approved Label | Purpose |
+|-------|--------------|----------------|---------|
+| Triage | `forge:task-triage-pending` | _N/A_ | Standalone ticket is missing required fields; waiting for update |
+| Plan Approval | `forge:task-plan-pending` | `forge:task-plan-approved` | Plan is posted; waiting for approval |
+
 ## Control Labels
 
 | Label | Purpose |
 |-------|---------|
 | `forge:managed` | Marks the ticket for Forge automation. Add this when creating a ticket to start the workflow. |
+| `forge:task-takeover` | Triggers the Task Takeover workflow for standalone Tasks or Epics. |
+| `forge:managed:task` | Identity preservation label used during Task Takeover transitions. |
+| `forge:managed:task-takeover` | Identity preservation label used during Task Takeover transitions. |
 | `forge:blocked` | Set by Forge when a stage fails. Forge posts a comment with the error. |
 | `forge:retry` | Add this to resume from the exact node that failed. Forge removes it after resuming. |
 
 ## How to Use Labels
 
-**Starting a workflow:** Create a Jira issue and add `forge:managed`. Forge detects the issue type (Feature or Bug) and begins the appropriate pipeline.
+**Starting a workflow:** Create a Jira issue and add `forge:managed`. Forge detects the issue type (Feature or Bug) and begins the appropriate pipeline.
For standalone Tasks or Epics, add `forge:task-takeover` (or another configured trigger label) to initiate the Task Takeover workflow.
 
 **Approving a stage:** When Forge posts a PRD, spec, or other artifact, it sets the `forge:*-pending` label. Change it to `forge:*-approved` to advance the workflow. Do not add the approved label manually before Forge posts — it won't be recognized until the pending state is set.
 
diff --git a/docs/reference/config.md b/docs/reference/config.md
index 72f94b5d..30381d71 100644
--- a/docs/reference/config.md
+++ b/docs/reference/config.md
@@ -128,3 +128,21 @@ These variables are used by `docker-compose.yml`, `devtools/docker-compose.dev.y
 ### MCP Servers
 
 MCP server configuration lives in `mcp-servers.json`, not `.env`. See the [MCP servers section](https://github.com/forge-sdlc/forge/blob/main/mcp-servers.json) of the repository.
+
+## Task Takeover Configuration
+
+Task Takeover allows Forge to process standalone Task and Epic issues directly from Jira. When a standalone Task or Epic carries a Task Takeover trigger label, Forge bypasses the parent validation check and executes the task directly.
+
+Define these options under the `task_takeover` key of `Settings`. They can be supplied as a JSON string in the `TASK_TAKEOVER` environment variable or through the application config.
+
+### Settings Schema
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `enabled` | `bool` | `False` | Whether Task Takeover is enabled. |
+| `issue_types` | `list[str]` | `[]` | List of Jira issue types that support task takeover (e.g., `["Task", "Epic"]`). |
+| `require_tests` | `bool` | `True` | Whether tests must pass before the code is merged. |
+| `review_max_attempts` | `int` | `2` | Maximum number of PR review fix attempts. |
+| `labels.trigger` | `str` | `"forge:task-takeover"` | Label that triggers the Task Takeover workflow.
| +| `labels.pending` | `str` | `"forge:task-plan-pending"` | Label set by Forge when a task plan is pending approval. | +| `labels.approved` | `str` | `"forge:task-plan-approved"` | Label used by humans to approve the task plan. | diff --git a/src/forge/api/routes/jira.py b/src/forge/api/routes/jira.py index 4741b721..ea76a380 100644 --- a/src/forge/api/routes/jira.py +++ b/src/forge/api/routes/jira.py @@ -109,7 +109,7 @@ async def receive_jira_webhook( # Record webhook received metric record_webhook_received(source="jira", event_type=webhook_data.event_type) - # Filter: only process issues with forge:managed label + # Filter: only process issues with forge:managed label or task-takeover triggers issue_labels = payload.get("issue", {}).get("fields", {}).get("labels", []) has_forge_managed = "forge:managed" in issue_labels @@ -122,7 +122,27 @@ async def receive_jira_webhook( has_forge_managed = True break - if not has_forge_managed: + # Detect task-takeover trigger labels + has_takeover_trigger = False + if settings.task_takeover and settings.task_takeover.enabled: + takeover_triggers = { + "forge:task-takeover", + "forge:managed:task", + "forge:managed:task-takeover", + } + if settings.task_takeover.labels and settings.task_takeover.labels.trigger: + takeover_triggers.add(settings.task_takeover.labels.trigger) + + has_takeover_trigger = any(label in issue_labels for label in takeover_triggers) + for item in changelog_items: + if item.get("field") == "labels": + to_labels = item.get("toString", "") or "" + updated_labels = to_labels.split() + if any(label in updated_labels for label in takeover_triggers): + has_takeover_trigger = True + break + + if not (has_forge_managed or has_takeover_trigger): span.set_attribute("forge.skipped", True) span.set_attribute("forge.skip_reason", "missing forge:managed label") logger.debug(f"Skipping {webhook_data.ticket_key}: missing forge:managed label") @@ -163,6 +183,13 @@ async def receive_jira_webhook( f"Routing {issue_type} 
{source_ticket_key} webhook " f"to parent Feature {routing_ticket_key}" ) + elif has_takeover_trigger and issue_type in ("Epic", "Task"): + # Bypass parent validation for Epic/Task if takeover trigger label is present. + # routing_ticket_key remains webhook_data.ticket_key, source_ticket_key remains None. + logger.info( + f"Bypassing parent checks for standalone {issue_type} " + f"{webhook_data.ticket_key} due to task-takeover trigger label." + ) else: # Epics/Tasks without forge:parent are invalid - reject span.set_attribute("forge.skipped", True) diff --git a/src/forge/config.py b/src/forge/config.py index e1bc7db9..27e445e0 100644 --- a/src/forge/config.py +++ b/src/forge/config.py @@ -4,7 +4,7 @@ from functools import cached_property, lru_cache from typing import TYPE_CHECKING, Literal -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict if TYPE_CHECKING: @@ -13,6 +13,24 @@ logger = logging.getLogger(__name__) +class TaskTakeoverLabels(BaseModel): + """Labels used for task takeover workflow.""" + + trigger: str = "forge:task-takeover" + pending: str = "forge:task-plan-pending" + approved: str = "forge:task-plan-approved" + + +class TaskTakeoverSettings(BaseModel): + """Settings configuration for task takeover.""" + + enabled: bool = False + issue_types: list[str] = Field(default_factory=list) + labels: TaskTakeoverLabels = Field(default_factory=TaskTakeoverLabels) + require_tests: bool = True + review_max_attempts: int = 2 + + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -360,6 +378,12 @@ def ignored_ci_checks(self) -> list[str]: description="Enable distributed tracing", ) + # Task Takeover Configuration + task_takeover: TaskTakeoverSettings = Field( + default_factory=TaskTakeoverSettings, + description="Configuration settings for Task Takeover feature", + ) + @property def langfuse_enabled(self) -> bool: """Check 
if Langfuse tracing is enabled and configured.""" diff --git a/src/forge/integrations/agents/agent.py b/src/forge/integrations/agents/agent.py index 2b69f5a1..2497bbda 100644 --- a/src/forge/integrations/agents/agent.py +++ b/src/forge/integrations/agents/agent.py @@ -1159,17 +1159,40 @@ async def answer_question( generation_context = context.get("generation_context", {}) raw_requirements = generation_context.get("raw_requirements", "Not available") - prompt = load_prompt( - "answer-question", - artifact_type=artifact_type, - artifact_content=artifact_content, - raw_requirements=raw_requirements, - question=question, - ) + ticket_type = context.get("ticket_type") + ticket_type_str = "" + if ticket_type is not None: + ticket_type_str = ( + ticket_type.value if hasattr(ticket_type, "value") else str(ticket_type) + ) + + if ( + artifact_type == "plan" + and context.get("current_node") == "task_plan_approval_gate" + and ticket_type_str == "task" + ): + prompt = load_prompt( + "task-takeover-qa", + ticket_key=context.get("ticket_key", ""), + summary=context.get("summary", ""), + description=context.get("description", ""), + plan_content=artifact_content, + question=question, + ) + task_name = "task-takeover-qa" + else: + prompt = load_prompt( + "answer-question", + artifact_type=artifact_type, + artifact_content=artifact_content, + raw_requirements=raw_requirements, + question=question, + ) + task_name = "answer-question" - logger.info(f"Answering question about {artifact_type}") + logger.info(f"Answering question about {artifact_type} using task={task_name}") result = await self.run_task( - task="answer-question", + task=task_name, prompt=prompt, context={ "artifact_type": artifact_type, diff --git a/src/forge/integrations/jira/client.py b/src/forge/integrations/jira/client.py index 46abecb4..82dfcb66 100644 --- a/src/forge/integrations/jira/client.py +++ b/src/forge/integrations/jira/client.py @@ -742,13 +742,15 @@ async def set_workflow_label( # Get current labels 
current_labels = await self.get_labels(issue_key) - # Find forge: labels to remove (except the new one and forge:managed) + # Find forge: labels to remove (except the new one, forge:managed, and identity preservation labels) labels_to_remove = [ label for label in current_labels if label.startswith(remove_prefix) and label != new_label.value and label != ForgeLabel.FORGE_MANAGED.value + and label != "forge:managed:task" + and label != "forge:managed:task-takeover" ] # Build update operations diff --git a/src/forge/integrations/jira/models.py b/src/forge/integrations/jira/models.py index b32c406c..8d94725d 100644 --- a/src/forge/integrations/jira/models.py +++ b/src/forge/integrations/jira/models.py @@ -183,6 +183,8 @@ def extract_children(nodes: list[dict[str, Any]]) -> list[str]: return [block for block in blocks if block] blocks = extract_blocks(adf) + if adf.get("type") == "doc" and not blocks: + return "" return "\n\n".join(blocks) if blocks else str(adf) diff --git a/src/forge/models/workflow.py b/src/forge/models/workflow.py index 3904386a..578f8915 100644 --- a/src/forge/models/workflow.py +++ b/src/forge/models/workflow.py @@ -124,6 +124,12 @@ class ForgeLabel(StrEnum): RCA_APPROVED = "forge:rca-approved" TRIAGE_PENDING = "forge:triage-pending" + # Task Takeover workflow + TASK_TAKEOVER = "forge:task-takeover" + TASK_TRIAGE_PENDING = "forge:task-triage-pending" + TASK_PLAN_PENDING = "forge:task-plan-pending" + TASK_PLAN_APPROVED = "forge:task-plan-approved" + # General FORGE_MANAGED = "forge:managed" BLOCKED = "forge:blocked" diff --git a/src/forge/orchestrator/worker.py b/src/forge/orchestrator/worker.py index 7c56de78..e6f994e5 100644 --- a/src/forge/orchestrator/worker.py +++ b/src/forge/orchestrator/worker.py @@ -61,6 +61,7 @@ def _is_workflow_errored(state: dict) -> bool: "prd_approval_gate", "spec_approval_gate", "plan_approval_gate", + "task_plan_approval_gate", "task_approval_gate", "rca_option_gate", } @@ -258,9 +259,10 @@ async def 
_process_workflow(self, message: QueueMessage) -> None: ) else: # Use router to resolve which workflow to use + labels = message.payload.get("issue", {}).get("fields", {}).get("labels", []) or [] workflow_instance = self.router.resolve( ticket_type=ticket_type, - labels=[], # TODO: Extract labels from message payload + labels=labels, event=message.payload, ) @@ -580,6 +582,8 @@ async def _handle_resume_event( approval_stage = "prd" elif "spec-approved" in to_labels.lower(): approval_stage = "spec" + elif "task-plan-approved" in to_labels.lower(): + approval_stage = "task_plan" elif "plan-approved" in to_labels.lower(): approval_stage = "plan" elif "task-approved" in to_labels.lower(): @@ -597,10 +601,18 @@ async def _handle_resume_event( "decompose_epics": "plan", "regenerate_all_epics": "plan", "update_single_epic": "plan", + "task_plan_approval_gate": "task_plan", "task_approval_gate": "task", "generate_tasks": "task", } expected_stage = node_to_stage.get(current_node) + if current_node == "plan_approval_gate" and current_state.get("ticket_type") in ( + "Task", + "Epic", + TicketType.TASK, + TicketType.EPIC, + ): + expected_stage = "task_plan" if approval_stage and expected_stage and approval_stage == expected_stage: is_approved = True @@ -622,7 +634,11 @@ async def _handle_resume_event( gate_to_approved_label = { "prd_approval_gate": "forge:prd-approved", "spec_approval_gate": "forge:spec-approved", - "plan_approval_gate": "forge:plan-approved", + "plan_approval_gate": "forge:task-plan-approved" + if current_state.get("ticket_type") + in ("Task", "Epic", TicketType.TASK, TicketType.EPIC) + else "forge:plan-approved", + "task_plan_approval_gate": "forge:task-plan-approved", "task_approval_gate": "forge:task-approved", } expected_label = gate_to_approved_label.get(current_node) @@ -1121,6 +1137,7 @@ async def _handle_resume_event( "plan_approval_gate", "task_approval_gate", "plan_approval_gate_bug", + "task_plan_approval_gate", } prev_error = 
current_state.get("last_error") is_paused_at_gate = current_state.get("is_paused") and current_node in approval_gates @@ -1314,6 +1331,7 @@ def _stage_label_for_node(current_node: str) -> str: "update_single_epic": "the plan", "rca_option_gate": "the RCA", "plan_approval_gate_bug": "the plan", + "task_plan_approval_gate": "the task plan", "task_approval_gate": "the tasks", "generate_tasks": "the tasks", "regenerate_all_tasks": "the tasks", @@ -1535,8 +1553,23 @@ def _extract_ticket_type(self, message: QueueMessage) -> TicketType: # by the Jira webhook handler. The payload still carries the child's # issue type, which won't match any workflow. Fall through to UNKNOWN # so _find_workflow_by_state resolves it from checkpoint. + # stand-alone task takeover events (which have trigger labels) bypass child checks. + labels = fields.get("labels", []) or [] + takeover_triggers = { + "forge:task-takeover", + "forge:managed:task", + "forge:managed:task-takeover", + } + if ( + self.settings.task_takeover + and self.settings.task_takeover.labels + and self.settings.task_takeover.labels.trigger + ): + takeover_triggers.add(self.settings.task_takeover.labels.trigger) + is_takeover = any(label in labels for label in takeover_triggers) + child_types = {"Epic", "Task", "Sub-task"} - if ticket_type_str in child_types: + if ticket_type_str in child_types and not is_takeover: return TicketType.UNKNOWN # Map string to TicketType enum diff --git a/src/forge/prompts/v1/task-takeover-planning.md b/src/forge/prompts/v1/task-takeover-planning.md new file mode 100644 index 00000000..6e323ef4 --- /dev/null +++ b/src/forge/prompts/v1/task-takeover-planning.md @@ -0,0 +1,47 @@ +## Task Ticket + +**Key:** {ticket_key} +**Summary:** {summary} + +**Description:** +{description} + +**Comments:** +{comments} + +## Available Repositories + +Use only these exact repository names when tagging `repo:/` in the plan: + +{known_repos} + +## File Metadata + +Here is the file metadata gathered from the 
repository to help guide your plan: + +{file_metadata} + +## Repository Grounding Requirements + +Before writing `.forge/plan.md`, inspect the relevant repository using available repository, GitHub, or filesystem tools. + +- Read repo guidance when present: `AGENTS.md`, `CLAUDE.md`, `.claude/AGENTS.md`, `.claude/CLAUDE.md`, `README.md`, `CONTRIBUTING.md`, `Makefile`, language-specific project files, docs, and repo-local skills or agent instructions. +- Confirm planned files, functions/classes, test locations, generated-file requirements, and validation commands against real repository contents. +- Follow discovered repository standards for architecture, naming, error handling, testing, packaging, documentation, and local agent workflow. +- Prefer codebase exploration focused on the ticket description, proposed solution/approach, nearby code, and validation commands. Broaden the search when needed to understand the context safely. Do not inspect project-management metadata such as unrelated branches, open issues, pull requests, milestones, or release boards unless explicitly required. +- Use nearby code and test patterns instead of guessing from path names alone. +- Do not invent generic paths, symbols, frameworks, test runners, or directory layouts. If repository inspection is unavailable, write the plan with an explicit blocking note explaining what repo access or configuration is required. + +## Formulate Implementation Plan + +Formulate a concrete implementation plan mapping the proposed solution to specific target files and test plans. + +Your plan MUST include: +1. **Target Files**: List the specific, existing repository files to be modified, or new files to be created, incorporating the gathered file metadata and repository inspection. +2. **Implementation Steps**: Clear, sequential steps for implementing the proposed solution/approach. +3. **Test Plans**: A detailed validation plan describing how the changes will be tested. 
Map the proposed solutions to concrete unit or integration tests, naming specific test commands and test files (existing or new) to run.
+
+---
+
+Produce a detailed implementation plan.
+Write the plan to `.forge/plan.md`.
diff --git a/src/forge/prompts/v1/task-takeover-qa.md b/src/forge/prompts/v1/task-takeover-qa.md
new file mode 100644
index 00000000..83700804
--- /dev/null
+++ b/src/forge/prompts/v1/task-takeover-qa.md
@@ -0,0 +1,30 @@
+You are answering a user's clarifying question during the interactive planning gate of Task Takeover.
+
+## Task Ticket
+
+**Key:** {ticket_key}
+**Summary:** {summary}
+
+**Description:**
+{description}
+
+## Current Implementation Plan
+
+{plan_content}
+
+## Question / Feedback
+
+{question}
+
+## Instructions
+
+Formulate a high-quality, professional, and technically accurate response to the user's question or feedback.
+
+Your response should:
+1. **Directly Address the Question**: Provide clear, specific answers or explanations for each point raised in the question/feedback.
+2. **Reference the Ticket and Plan**: Base your reasoning on the ticket description, comments, and the current draft of the implementation plan.
+3. **Incorporate Repository Context**: When the question asks about specific files, tests, commands, or project conventions, refer to real file paths and code structures in the repository. Do not guess or invent details.
+4. **Suggest Actionable Updates**: If the user's feedback requires changes to the plan, clearly explain how the plan should be updated to address their concerns.
+5. **Follow Repository Standards**: Ensure the proposed solutions align with the project's architecture, naming conventions, error handling, testing, and other conventions.
+
+Format your answer in clear prose. Do not use excessive markdown formatting.
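Reviewer note: the `task-takeover-qa` template is filled with `ticket_key`, `summary`, `description`, `plan_content`, and `question` by the `load_prompt(...)` call in `agent.py`. The real `load_prompt` is not part of this diff, so the following is only a minimal sketch of how such a template might be rendered, assuming plain `str.format`-style substitution. `QA_TEMPLATE` (an abbreviated copy of the template), `required_fields`, and `render_prompt` are hypothetical names introduced for illustration.

```python
from string import Formatter

# Abbreviated copy of the task-takeover-qa template (illustrative only).
QA_TEMPLATE = """## Task Ticket

**Key:** {ticket_key}
**Summary:** {summary}

**Description:**
{description}

## Current Implementation Plan

{plan_content}

## Question / Feedback

{question}
"""


def required_fields(template: str) -> set[str]:
    """Return the placeholder names that appear in a format-style template."""
    return {name for _, name, _, _ in Formatter().parse(template) if name}


def render_prompt(template: str, **values: str) -> str:
    """Render the template, failing loudly if any placeholder is missing."""
    missing = required_fields(template) - set(values)
    if missing:
        raise KeyError(f"missing prompt variables: {sorted(missing)}")
    return template.format(**values)
```

Checking placeholders up front gives a clearer error than a bare `KeyError` raised from inside `str.format`. Values that themselves contain braces (for example a ticket description with JSON) pass through unchanged, because `str.format` does not re-parse substituted values.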
diff --git a/src/forge/prompts/v1/task-takeover-review.md b/src/forge/prompts/v1/task-takeover-review.md
new file mode 100644
index 00000000..cf33100c
--- /dev/null
+++ b/src/forge/prompts/v1/task-takeover-review.md
@@ -0,0 +1,31 @@
+## Task Takeover Qualitative Review
+
+You are a senior read-only LLM code reviewer. Your job is to assess the git diff of the implemented changes against the Jira ticket's "Acceptance Criteria".
+
+### Ticket Acceptance Criteria
+{acceptance_criteria}
+
+### Git Diff of Implemented Changes
+{git_diff}
+
+---
+
+## Qualitative Review Guidelines & Assertions
+
+Carefully evaluate the git diff and check the following two assertions:
+1. **Acceptance Criteria**: Verify whether every acceptance criterion is fully met.
+2. **Automated Test Coverage**: Verify that at least one automated test has been written or updated in the diff to cover the changes.
+
+## Output Format
+
+Your response must contain exactly one of the following verdicts on its own line:
+`verdict: adequate`
+or
+`verdict: tests_incomplete`
+
+Follow the verdict with your constructive feedback in this format:
+`feedback: `
+
+Only these two verdict values are valid: `adequate` or `tests_incomplete`.
+- Use `adequate` only if both assertions are completely satisfied (every acceptance criterion is fully met and at least one automated test is written or updated).
+- Use `tests_incomplete` if any acceptance criterion is unmet, or if no automated test has been written or updated.
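Reviewer note: the output format in the review prompt (one `verdict:` line plus a `feedback:` section) has to be parsed by the consuming node. The `run_qualitative_review` implementation is not shown in this diff, so the following is a hedged sketch of one way to parse it, not the code Forge ships. `ReviewResult`, `parse_review_response`, and the choice to fall back to `tests_incomplete` on malformed output are assumptions made for illustration.

```python
import re
from typing import NamedTuple

VALID_VERDICTS = {"adequate", "tests_incomplete"}
# Assumption: treat anything malformed as a failed review (fail closed).
FALLBACK_VERDICT = "tests_incomplete"

# A verdict must sit on its own line; optional backticks are tolerated.
_VERDICT_RE = re.compile(
    r"^\s*`?verdict:\s*([a-z_]+)`?\s*$", re.IGNORECASE | re.MULTILINE
)
# Feedback runs from the first "feedback:" marker to the end of the text.
_FEEDBACK_RE = re.compile(
    r"^\s*`?feedback:\s*(.*)", re.IGNORECASE | re.MULTILINE | re.DOTALL
)


class ReviewResult(NamedTuple):
    verdict: str
    feedback: str


def parse_review_response(text: str) -> ReviewResult:
    """Extract the verdict and feedback from a reviewer response."""
    found = {m.group(1).lower() for m in _VERDICT_RE.finditer(text)}
    # Accept only a single, unambiguous, valid verdict.
    if len(found) == 1 and found <= VALID_VERDICTS:
        verdict = next(iter(found))
    else:
        verdict = FALLBACK_VERDICT
    match = _FEEDBACK_RE.search(text)
    feedback = match.group(1).strip() if match else ""
    return ReviewResult(verdict, feedback)
```

Failing closed means a missing, unknown, or contradictory verdict sends the change back for another fix attempt rather than letting it through, which fits the bounded `review_max_attempts` setting documented in this PR.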
diff --git a/src/forge/prompts/v1/task-takeover-triage.md b/src/forge/prompts/v1/task-takeover-triage.md new file mode 100644 index 00000000..838ea7f7 --- /dev/null +++ b/src/forge/prompts/v1/task-takeover-triage.md @@ -0,0 +1,41 @@ +## Task Ticket + +**Summary:** {summary} + +**Description:** +{description} + +**Comments:** +{comments} + +--- + +### System Guidelines + +You are an AI software engineer evaluating the completeness of a Task/Epic ticket for Task Takeover triage. + +Evaluate the ticket description and comments to check if they provide enough clear, actionable information to formulate a concrete implementation plan. You must strictly enforce the presence and clarity of the following three mandatory sections: + +1. **Problem Statement**: A clear statement of what the current problem is, why it occurs, or what new capability is required. +2. **Proposed Solution/Approach**: A concrete plan, design, or guidance on how to implement the solution. +3. **Acceptance Criteria**: A list of specific requirements, behaviors, or conditions that must be satisfied to consider the task complete. + +### Output Format + +Output exactly one of the following: + +1. If all three mandatory sections ("Problem Statement", "Proposed Solution/Approach", "Acceptance Criteria") are sufficiently detailed and clear to begin planning, output ONLY the exact bare string: +sufficient + +2. If any of the three sections are missing, incomplete, or require clarification, output ONLY a JSON array of the missing/incomplete fields. Choose only from these three exact names: +[ + "Problem Statement", + "Proposed Solution/Approach", + "Acceptance Criteria" +] + +Strictly adhere to the following output rules: +- Do NOT wrap your output in markdown code blocks (such as ``` or ```json). +- Do NOT include any additional comments, explanations, greetings, or whitespace. +- If sufficient, output only the word "sufficient" (case-insensitive). 
+- If insufficient, output only a valid JSON list of strings representing the missing fields. diff --git a/src/forge/workflow/gates/__init__.py b/src/forge/workflow/gates/__init__.py index 41894c2e..9cb10e29 100644 --- a/src/forge/workflow/gates/__init__.py +++ b/src/forge/workflow/gates/__init__.py @@ -20,6 +20,10 @@ route_task_approval, task_approval_gate, ) +from forge.workflow.gates.task_plan_approval import ( + route_task_plan_approval, + task_plan_approval_gate, +) __all__ = [ "prd_approval_gate", @@ -30,4 +34,6 @@ "route_plan_approval", "route_task_approval", "task_approval_gate", + "route_task_plan_approval", + "task_plan_approval_gate", ] diff --git a/src/forge/workflow/gates/task_plan_approval.py b/src/forge/workflow/gates/task_plan_approval.py new file mode 100644 index 00000000..0ae6b803 --- /dev/null +++ b/src/forge/workflow/gates/task_plan_approval.py @@ -0,0 +1,92 @@ +"""Task plan approval gate for standalone task-takeover workflow review. + +The task plan approval workflow uses labels: +- forge:task-plan-pending - Task plan awaiting approval +- forge:task-plan-approved - Task plan approved (triggers isolated execution workspace setup) + +To approve: Change label to forge:task-plan-approved +To request revision: Add a comment with prefix '!' (keep forge:task-plan-pending) +To ask clarifying questions: Add a comment with prefix '?' or '@forge ask' +""" + +import logging +from typing import Any, cast + +from langgraph.graph import END + +from forge.api.routes.metrics import record_approval, record_revision_requested +from forge.workflow.task_takeover.state import TaskTakeoverState +from forge.workflow.utils import set_paused +from forge.workflow.utils.comment_classifier import CommentType, classify_comment + +logger = logging.getLogger(__name__) + + +def task_plan_approval_gate(state: TaskTakeoverState) -> TaskTakeoverState: + """Pause task takeover workflow for human review of the generated plan. 
+ + Args: + state: Current task takeover workflow state. + + Returns: + State with is_paused=True and current_node="task_plan_approval_gate". + """ + ticket_key = state.get("ticket_key", "unknown") + logger.info(f"Task plan approval gate: pausing workflow for {ticket_key}") + return cast( + TaskTakeoverState, + set_paused(cast(dict[str, Any], state), "task_plan_approval_gate"), + ) + + +def route_task_plan_approval(state: TaskTakeoverState) -> str: + """Route after task plan approval gate resumes. + + Args: + state: Current TaskTakeoverState. + + Returns: + Name of the next node or END. + """ + ticket_key = state.get("ticket_key", "unknown") + feedback = state.get("feedback_comment") + is_question = state.get("is_question", False) + revision_requested = state.get("revision_requested", False) + + # Classify comment text if available + if feedback: + comment_type = classify_comment(feedback) + if comment_type == CommentType.QUESTION: + is_question = True + elif comment_type == CommentType.FEEDBACK: + revision_requested = True + + # 1. Q&A Mode + if is_question: + logger.info(f"Q&A mode: routing to answer_question for {ticket_key}") + return "answer_question" + + # 2. Revision/Feedback requested (comment starting with !) + if revision_requested: + logger.info(f"Revision requested for {ticket_key}: routing to regenerate_plan") + record_revision_requested("task_plan") + return "regenerate_plan" + + # 3. YOLO Mode + if state.get("yolo_mode"): + logger.info(f"YOLO mode: auto-approving task plan for {ticket_key}") + record_approval("task_plan") + return "setup_workspace" + + # 4. If still paused, remain in paused state + if state.get("is_paused"): + logger.info( + f"Task plan approval gate: workflow paused for {ticket_key}, " + "waiting for approval webhook/label update" + ) + return END + + # 5. 
Approved -> route to isolated execution setup node (setup_workspace) + logger.info(f"Task plan approved for {ticket_key}, proceeding to workspace setup") + record_approval("task_plan") + return "setup_workspace" diff --git a/src/forge/workflow/nodes/__init__.py b/src/forge/workflow/nodes/__init__.py index 676a5903..4a348c87 100644 --- a/src/forge/workflow/nodes/__init__.py +++ b/src/forge/workflow/nodes/__init__.py @@ -63,6 +63,11 @@ route_tasks_parallel, should_use_parallel_execution, ) +from forge.workflow.nodes.task_takeover_execution import execute_task_changes +from forge.workflow.nodes.task_takeover_planning import generate_plan +from forge.workflow.nodes.task_takeover_pr import create_task_takeover_pr +from forge.workflow.nodes.task_takeover_review import run_qualitative_review +from forge.workflow.nodes.task_takeover_triage import triage_task from forge.workflow.nodes.triage import route_triage_gate, triage_check, triage_gate from forge.workflow.nodes.workspace_setup import ( get_workspace_manager, @@ -120,6 +125,16 @@ "triage_check", "triage_gate", "route_triage_gate", + # Task takeover workflow — triage + "triage_task", + # Task takeover workflow — planning + "generate_plan", + # Task takeover workflow — execution + "execute_task_changes", + # Task takeover workflow — PR creation + "create_task_takeover_pr", + # Task takeover workflow — review + "run_qualitative_review", # Bug workflow — RCA analysis "analyze_bug", "reflect_rca", diff --git a/src/forge/workflow/nodes/qa_handler.py b/src/forge/workflow/nodes/qa_handler.py index cce9e73e..0ff46033 100644 --- a/src/forge/workflow/nodes/qa_handler.py +++ b/src/forge/workflow/nodes/qa_handler.py @@ -99,6 +99,17 @@ async def answer_question(state: WorkflowState) -> WorkflowState: artifact_content = _get_artifact_content(state, artifact_type) generation_context = state.get("generation_context", {}).get(artifact_type, {}) + # Fetch issue details for Q&A if not already present in state + summary = 
state.get("summary") or "" + description = state.get("description") or "" + if not summary or not description: + try: + issue = await jira.get_issue(ticket_key) + summary = summary or issue.summary or "" + description = description or issue.description or "" + except Exception as ex: + logger.warning(f"Could not fetch issue for Q&A: {ex}") + # Generate answer using agent answer = await agent.answer_question( question=question, @@ -112,6 +123,8 @@ async def answer_question(state: WorkflowState) -> WorkflowState: "retry_count": state.get("retry_count", 0), "artifact_type": artifact_type, "generation_context": generation_context, + "summary": summary, + "description": description, }, ) diff --git a/src/forge/workflow/nodes/task_takeover_execution.py b/src/forge/workflow/nodes/task_takeover_execution.py new file mode 100644 index 00000000..12687bb0 --- /dev/null +++ b/src/forge/workflow/nodes/task_takeover_execution.py @@ -0,0 +1,182 @@ +"""Task execution node for Task Takeover workflow.""" + +import contextlib +import logging +from pathlib import Path +from typing import cast + +from forge.config import get_settings +from forge.integrations.jira.client import JiraClient +from forge.sandbox.runner import ContainerConfig, ContainerRunner +from forge.workflow.task_takeover.state import TaskTakeoverState +from forge.workflow.utils import update_state_timestamp +from forge.workflow.utils.jira_status import post_status_comment +from forge.workspace.git_ops import GitOperations +from forge.workspace.manager import Workspace + +logger = logging.getLogger(__name__) + + +async def execute_task_changes(state: TaskTakeoverState) -> TaskTakeoverState: + """Execute code modifications and run tests in a container sandbox. + + Args: + state: Current TaskTakeoverState. + + Returns: + Updated TaskTakeoverState. 
+ """ + ticket_key = state["ticket_key"] + workspace_path = state.get("workspace_path") + current_repo = state.get("current_repo", "") + branch_name = state.get("context", {}).get("branch_name", "") + current_task = state.get("current_task_key") or ticket_key + + settings = get_settings() + jira = JiraClient(settings) + + if not workspace_path: + logger.error(f"No workspace for task execution on {ticket_key}") + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "last_error": "Workspace not set up", + "current_node": "execute_task_changes", + } + ), + ) + + try: + # Get details from Jira for task implementation context + task_issue = await jira.get_issue(current_task) + task_summary = task_issue.summary + task_description = task_issue.description or "" + plan_content = state.get("plan_content") or "" + + # Post status comment that we are starting execution + await post_status_comment( + jira, + ticket_key, + f"🔨 Forge is implementing changes and tests for [{current_task}]: {task_summary}", + ) + + # Build task description with requirements injected + review_feedback = state.get("review_feedback") + feedback_section = "" + if review_feedback: + feedback_section = f"## Previous Qualitative Review Feedback\nPlease address the following feedback from the qualitative review:\n{review_feedback}\n\n" + + task_prompt = ( + f"You are implementing changes for task takeover [{current_task}].\n\n" + f"{feedback_section}" + f"## Approved Implementation Plan\n{plan_content}\n\n" + f"## Task Description\n{task_description}\n\n" + f"## Critical Instructions\n" + f"1. Read and understand the existing codebase.\n" + f"2. Apply code modifications according to the approved plan.\n" + f"3. You MUST inject at least one new or modified test file inside the workspace to verify the changes.\n" + f"4. Run compilation and local test suite commands inside the container workspace.\n" + f"5. 
Feed any build/test error and failure logs directly back to your reasoning process to enable iterative self-correction.\n" + f"6. Make sure all compilation and local tests pass successfully before finishing.\n" + ) + + # Initialize ContainerRunner matching sandbox configuration + runner = ContainerRunner(settings) + config = ContainerConfig() + + # Run task execution inside the container + result = await runner.run( + workspace_path=Path(workspace_path), + task_summary=f"Execute task takeover changes for {current_task}", + task_description=task_prompt, + config=config, + ticket_key=ticket_key, + task_key=current_task, + repo_name=current_repo, + previous_task_keys=state.get("implemented_tasks", []), + ) + + # Initialize GitOperations on the host to stage and commit + workspace_obj = Workspace( + path=Path(workspace_path), + repo_name=current_repo or "", + branch_name=branch_name or "", + ticket_key=ticket_key, + ) + git = GitOperations(workspace_obj) + + committed = False + commit_message = ( + f"[{current_task}] feat: implement task takeover execution changes and tests" + ) + + # Check for uncommitted changes on host and stage/commit + if git.has_uncommitted_changes(): + git.stage_all() + committed = git.commit(commit_message) + + current_sha = git.get_current_sha() + + # Post status comment based on results + if result.success: + await post_status_comment( + jira, + ticket_key, + f"✅ Task takeover implementation succeeded. Created commit: {commit_message[:50]}...", + ) + else: + await post_status_comment( + jira, + ticket_key, + f"⚠️ Task takeover implementation failed/exited with code {result.exit_code}. 
Logs recorded.", + ) + + # Store results, logs, and commit info in state + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "task_execution_results": { + "success": result.success, + "exit_code": result.exit_code, + "error_message": result.error_message, + }, + "task_execution_logs": { + "stdout": result.stdout, + "stderr": result.stderr, + }, + "commit_info": { + "sha": current_sha, + "message": commit_message, + "committed": committed, + }, + "current_node": "execute_task_changes", + "last_error": None if result.success else result.error_message, + "retry_count": 0 if result.success else state.get("retry_count", 0) + 1, + } + ), + ) + + except Exception as e: + logger.error(f"execute_task_changes failed for {ticket_key}: {e}") + with contextlib.suppress(Exception): + from forge.workflow.nodes.error_handler import notify_error + + await notify_error(state, str(e), "execute_task_changes") # type: ignore[arg-type] + + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "last_error": str(e), + "current_node": "execute_task_changes", + "retry_count": state.get("retry_count", 0) + 1, + } + ), + ) + finally: + await jira.close() diff --git a/src/forge/workflow/nodes/task_takeover_planning.py b/src/forge/workflow/nodes/task_takeover_planning.py new file mode 100644 index 00000000..3a2bbb1a --- /dev/null +++ b/src/forge/workflow/nodes/task_takeover_planning.py @@ -0,0 +1,306 @@ +"""Planning node for Task Takeover workflow.""" + +import contextlib +import logging +from pathlib import Path +from typing import Any, cast + +from forge.config import get_settings +from forge.integrations.jira.client import JiraClient +from forge.models.workflow import ForgeLabel +from forge.prompts import load_prompt +from forge.sandbox.runner import ContainerConfig, ContainerRunner +from forge.workflow.task_takeover.state import TaskTakeoverState +from forge.workflow.utils import set_paused, update_state_timestamp +from 
forge.workspace.git_ops import GitOperations +from forge.workspace.manager import WorkspaceManager + +logger = logging.getLogger(__name__) + +_MAX_COMMENT_CHARS = 25_000 +_TRUNCATION_NOTE = "*(Plan truncated — full plan available in container logs.)*" + +__all__ = ["generate_plan", "plan_approval_gate", "route_plan_approval"] + + +def _gather_file_metadata(workspace_path: Path) -> str: + """Gather file structure and metadata from the cloned workspace.""" + lines = [] + ignore_dirs = { + ".git", + "node_modules", + ".venv", + "__pycache__", + ".pytest_cache", + "dist", + "build", + "target", + ".mypy_cache", + ".ruff_cache", + ".forge", + } + + count = 0 + max_files = 300 + for path in sorted(workspace_path.rglob("*")): + try: + # Skip if any part is ignored + if any(part in ignore_dirs for part in path.relative_to(workspace_path).parts): + continue + except ValueError: + continue + + if path.is_file(): + # Skip common binary/unwanted extensions + if path.suffix.lower() in { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".ico", + ".pyc", + ".pdf", + ".zip", + ".tar", + ".gz", + ".woff", + ".woff2", + ".ttf", + ".eot", + }: + continue + + try: + rel_path = path.relative_to(workspace_path) + size = path.stat().st_size + lines.append(f"- {rel_path} ({size} bytes)") + count += 1 + if count >= max_files: + lines.append(f"- ... and more files (truncated at {max_files} files)") + break + except Exception: + continue + + if not lines: + return "No files found in workspace." 
+ return "\n".join(lines) + + +def _truncate_plan_comment(plan_content: str, max_chars: int = _MAX_COMMENT_CHARS) -> str: + """Truncate plan comment at last paragraph boundary before the character limit.""" + if len(plan_content) <= max_chars: + return plan_content + + available = max_chars - len(_TRUNCATION_NOTE) - 4 + truncated = plan_content[:available] + last_para = truncated.rfind("\n\n") + if last_para > 0: + truncated = truncated[:last_para] + + return truncated + "\n\n" + _TRUNCATION_NOTE + + +def _harvest_plan(workspace_path: Path) -> str: + """Read .forge/plan.md from the container workspace. + + Raises: + FileNotFoundError: if plan.md was not written. + ValueError: if plan.md is empty. + """ + plan_file = workspace_path / ".forge" / "plan.md" + if not plan_file.exists(): + raise FileNotFoundError(f"plan.md not found at {plan_file}") + content = plan_file.read_text() + if not content.strip(): + raise ValueError("plan.md is empty") + return content + + +async def generate_plan(state: TaskTakeoverState) -> TaskTakeoverState: + """Generate or regenerate task takeover plan. + + Args: + state: Current TaskTakeoverState. + + Returns: + Updated TaskTakeoverState. 
+ """ + ticket_key = state["ticket_key"] + retry_count = state.get("retry_count", 0) + is_revision = ( + state.get("revision_requested", False) or state.get("feedback_comment") is not None + ) + feedback_comment = state.get("feedback_comment") or "" + original_plan = state.get("plan_content") or "" + + settings = get_settings() + jira = JiraClient(settings) + + try: + issue = await jira.get_issue(ticket_key) + comments = await jira.get_comments(ticket_key) + comment_text = "\n\n".join(c.body for c in comments if c.body) + + # Notify Jira before we start container + if is_revision: + await jira.add_comment( + ticket_key, + "Revising the plan based on your feedback — this will take a few minutes.", + ) + else: + await jira.add_comment( + ticket_key, + "Starting implementation plan generation — gathering codebase metadata and drafting the plan. This will take a few minutes.", + ) + + # 1. Determine and clone/checkout repository + current_repo = state.get("current_repo") + if not current_repo: + with contextlib.suppress(Exception): + current_repo = await jira.get_project_default_repo(issue.project_key) + if not current_repo: + with contextlib.suppress(Exception): + repos = await jira.get_project_repos(issue.project_key) + if repos: + current_repo = repos[0] + + if not current_repo or current_repo == "unknown" or "/" not in current_repo: + raise ValueError(f"No valid repository found for project {issue.project_key}") + + # Update current_repo in state + state = cast(TaskTakeoverState, {**state, "current_repo": current_repo}) + + # 2. Get Workspace and clone if needed + workspace_manager = WorkspaceManager(base_dir=settings.workspace_base_dir) + workspace = workspace_manager.create_workspace( + repo_name=current_repo, + ticket_key=ticket_key, + ) + git = GitOperations(workspace) + if not (workspace.path / ".git").exists(): + git.clone() + + # 3. Gather repository file structure & metadata + file_metadata = _gather_file_metadata(workspace.path) + + # 4. 
Load project's known repos + known_repos: list[str] = [] + with contextlib.suppress(Exception): + known_repos = await jira.get_project_repos(issue.project_key) + if not known_repos: + known_repos = [current_repo] + + # 5. Formulate prompt + task_description = load_prompt( + "task-takeover-planning", + ticket_key=ticket_key, + summary=issue.summary or "", + description=issue.description or "", + comments=comment_text, + known_repos="\n".join(known_repos), + file_metadata=file_metadata, + ) + + # If this is a revision, append the feedback details to task_description + if is_revision: + task_description += f"\n\n## Revision Request\nThis is a revision request. Please update the original plan based on the feedback below.\n\n### Original Plan\n{original_plan}\n\n### Feedback Comment\n{feedback_comment}\n" + + # 6. Run container with ContainerRunner (skipping tests for planning speed) + runner = ContainerRunner(settings) + config = ContainerConfig(skip_tests=True) + result = await runner.run( + workspace_path=workspace.path, + task_summary=f"Plan task takeover for {ticket_key}", + task_description=task_description, + config=config, + ticket_key=ticket_key, + task_key=f"{ticket_key}-plan", + ) + + if not result.success: + raise RuntimeError( + f"Container failed with exit_code={result.exit_code}: {result.stderr}" + ) + + new_plan = _harvest_plan(workspace.path) + + # 7. 
Post the plan to Jira + truncated_comment = _truncate_plan_comment(new_plan) + await jira.add_comment(ticket_key, truncated_comment) + await jira.set_workflow_label(ticket_key, ForgeLabel.TASK_PLAN_PENDING) + + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "plan_content": new_plan, + "current_node": "task_plan_approval_gate", + "last_error": None, + "retry_count": 0, + "feedback_comment": None, + "revision_requested": False, + } + ), + ) + + except Exception as e: + logger.error(f"generate_plan failed for {ticket_key}: {e}") + new_retry = retry_count + 1 + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "last_error": str(e), + "current_node": "generate_plan", + "retry_count": new_retry, + } + ), + ) + finally: + if ( + "workspace" in locals() + and workspace + and "workspace_manager" in locals() + and workspace_manager + ): + workspace_manager.destroy_workspace(workspace) + await jira.close() + + +def plan_approval_gate(state: TaskTakeoverState) -> TaskTakeoverState: + """Pause and wait for plan approval. + + Args: + state: Current task takeover workflow state. + + Returns: + State with is_paused=True and current_node=plan_approval_gate. + """ + return cast(TaskTakeoverState, set_paused(cast(dict[str, Any], state), "plan_approval_gate")) + + +def route_plan_approval(state: TaskTakeoverState) -> str: + """Route after plan approval gate resumes. + + Checks state flags: + 1. is_paused -> END + 2. revision_requested -> generate_plan + 3. (otherwise, approved) -> END + + Args: + state: Current TaskTakeoverState. + + Returns: + Name of next node or END. 
+ """ + from langgraph.graph import END + + if state.get("is_paused"): + return END + + if state.get("revision_requested"): + return "generate_plan" + + return END diff --git a/src/forge/workflow/nodes/task_takeover_pr.py b/src/forge/workflow/nodes/task_takeover_pr.py new file mode 100644 index 00000000..e8c7c94c --- /dev/null +++ b/src/forge/workflow/nodes/task_takeover_pr.py @@ -0,0 +1,229 @@ +"""PR creation node for Task Takeover workflow.""" + +import asyncio +import contextlib +import logging +from pathlib import Path +from typing import cast + +from forge.integrations.github.client import GitHubClient +from forge.integrations.jira.client import JiraClient +from forge.workflow.nodes.workspace_setup import teardown_workspace +from forge.workflow.task_takeover.state import TaskTakeoverState as WorkflowState +from forge.workflow.utils import update_state_timestamp +from forge.workspace.git_ops import GitOperations +from forge.workspace.manager import Workspace + +logger = logging.getLogger(__name__) + + +async def cleanup_podman_containers(ticket_key: str) -> None: + """Stop and remove any running or stopped podman containers related to the ticket. + + Args: + ticket_key: Jira ticket key to match container names. 
+ """ + try: + # Find containers with name matching forge-{ticket_key}-* + proc = await asyncio.create_subprocess_exec( + "podman", + "ps", + "-a", + "--filter", + f"name=forge-{ticket_key}-", + "--format", + "{{.Names}}", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await proc.communicate() + container_names = [name.strip() for name in stdout.decode().splitlines() if name.strip()] + + for name in container_names: + logger.info(f"Stopping container: {name}") + stop_proc = await asyncio.create_subprocess_exec( + "podman", + "stop", + "-t", + "5", + name, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await stop_proc.wait() + + logger.info(f"Removing container: {name}") + rm_proc = await asyncio.create_subprocess_exec( + "podman", + "rm", + "-f", + name, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await rm_proc.wait() + except Exception as e: + logger.warning(f"Error during podman container teardown for {ticket_key}: {e}") + + +async def create_task_takeover_pr(state: WorkflowState) -> WorkflowState: + """Create a pull request from workspace changes for the task takeover workflow. + + This node: + 1. Synchronizes local changes with repository fork. + 2. Pushes local changes using GitOperations. + 3. Opens a Pull Request using GitHubClient. + 4. Posts the PR markdown link as a comment on Jira. + 5. Transitions Jira ticket status to "In Review". + 6. Tears down the workspace and container runner, freeing all resources. + + Args: + state: Current task takeover workflow state. + + Returns: + Updated state with PR details, workspace cleared.
+ """ + ticket_key = state["ticket_key"] + workspace_path = state.get("workspace_path") + current_repo = state.get("current_repo", "") + branch_name = state.get("context", {}).get("branch_name") or f"forge/{ticket_key.lower()}" + + if not workspace_path: + logger.error(f"No workspace for PR creation on {ticket_key}") + return cast( + WorkflowState, + update_state_timestamp( + { + **state, + "last_error": "Workspace not set up", + "current_node": "create_task_takeover_pr", + } + ), + ) + + github = GitHubClient() + jira = JiraClient() + + try: + # Step 1: Set up GitOperations + workspace = Workspace( + path=Path(workspace_path), + repo_name=current_repo, + branch_name=branch_name, + ticket_key=ticket_key, + ) + git = GitOperations(workspace) + + # Step 2: Push changes to fork + if not current_repo or "/" not in current_repo: + raise ValueError( + f"Invalid repository format '{current_repo}': must be in owner/repo format" + ) + + owner, repo = current_repo.split("/") + logger.info(f"Getting or creating fork for {current_repo}") + fork_data = await github.get_or_create_fork(owner, repo) + fork_owner = fork_data["owner"]["login"] + fork_repo = fork_data["name"] + + # Sync fork with upstream main branch + await github.sync_fork_with_upstream(fork_owner, fork_repo) + + # Add fork remote and push + git.add_fork_remote(fork_owner, fork_repo) + git.push_to_fork() + + # Step 3: Fetch Jira issue details to construct the PR title/description + ticket_summary = "" + ticket_description = "" + try: + ticket_issue = await jira.get_issue(ticket_key) + ticket_summary = ticket_issue.summary or "" + ticket_description = ticket_issue.description or "" + except Exception as e: + logger.warning(f"Could not fetch ticket details for PR: {e}") + + pr_title = f"[{ticket_key}] {ticket_summary or 'Task Takeover Implementation'}" + pr_body = ( + f"This Pull Request implements task takeover for ticket **[{ticket_key}]**.\n\n" + f"### Ticket Description\n" + f"{ticket_description}\n\n" + 
f"Co-authored-by: Forge " + ) + + # Step 4: Open a Pull Request from fork to upstream + pr_data = await github.create_pull_request( + owner=owner, + repo=repo, + title=pr_title, + body=pr_body, + head=f"{fork_owner}:{branch_name}", + base="main", + ) + pr_url = pr_data.get("html_url", "") + pr_number = pr_data.get("number") + + # Step 5: Post the PR markdown link as a comment on Jira + pr_label = f"PR #{pr_number}" if pr_number is not None else "Pull Request" + pr_markdown_link = f"[{pr_label}]({pr_url})" + comment_text = ( + f"🚀 Task takeover implementation complete. Pull Request created:\n\n{pr_markdown_link}" + ) + await jira.add_comment(ticket_key, comment_text) + + # Create remote link in Jira as well for better integration + with contextlib.suppress(Exception): + await jira.create_remote_link(ticket_key, pr_url, pr_label) + + # Step 6: Transition the Jira ticket status to "In Review" + await jira.transition_issue(ticket_key, "In Review") + + # Update PR URL lists + pr_urls = state.get("pr_urls", []) + if pr_url and pr_url not in pr_urls: + pr_urls.append(pr_url) + + # Update state with PR information before teardown + state_with_pr = { + **state, + "pr_urls": pr_urls, + "current_pr_url": pr_url, + "current_pr_number": pr_number, + "fork_owner": fork_owner, + "fork_repo": fork_repo, + } + + # Step 7: Teardown workspace and container runner resources + # Clean up any lingering container runners + await cleanup_podman_containers(ticket_key) + + # Clean up files and delete workspace + teardown_state = await teardown_workspace(cast(WorkflowState, state_with_pr)) + + return cast( + WorkflowState, + update_state_timestamp( + { + **teardown_state, + "current_node": "complete", + "last_error": None, + } + ), + ) + + except Exception as e: + logger.error(f"Task takeover PR creation node failed for {ticket_key}: {e}") + return cast( + WorkflowState, + update_state_timestamp( + { + **state, + "last_error": str(e), + "current_node": "create_task_takeover_pr", + 
"retry_count": state.get("retry_count", 0) + 1, + } + ), + ) + finally: + await github.close() + await jira.close() diff --git a/src/forge/workflow/nodes/task_takeover_review.py b/src/forge/workflow/nodes/task_takeover_review.py new file mode 100644 index 00000000..45821724 --- /dev/null +++ b/src/forge/workflow/nodes/task_takeover_review.py @@ -0,0 +1,197 @@ +"""Qualitative review node for Task Takeover workflow.""" + +import contextlib +import logging +import re +from pathlib import Path +from typing import cast + +from forge.config import get_settings +from forge.integrations.agents import ForgeAgent +from forge.integrations.jira.client import JiraClient +from forge.workflow.task_takeover.state import TaskTakeoverState as WorkflowState +from forge.workflow.utils import update_state_timestamp +from forge.workflow.utils.jira_status import post_status_comment +from forge.workspace.git_ops import GitOperations +from forge.workspace.manager import Workspace + +logger = logging.getLogger(__name__) + + +def _extract_acceptance_criteria(description: str) -> str: + """Extract Acceptance Criteria section from description, or fall back to the entire description.""" + if not description: + return "No description or acceptance criteria provided." + # Look for "Acceptance Criteria" case-insensitively + lower_desc = description.lower() + index = lower_desc.find("acceptance criteria") + if index != -1: + # Return everything from the found heading to the end + return description[index:].strip() + return description.strip() + + +def _get_git_diff(git: GitOperations) -> str: + """Retrieve git diff of the implemented changes.""" + for args in [("diff", "HEAD~1", "HEAD"), ("diff", "HEAD~1"), ("diff",), ("show", "HEAD")]: + try: + res = git._run_git(*args, check=False) + if res.returncode == 0 and res.stdout.strip(): + return cast(str, res.stdout) + except Exception: + continue + return "No changes detected or unable to retrieve git diff." 
+ + +def _parse_qualitative_review(output: str) -> tuple[str, str]: + """Parse qualitative review response to extract verdict and constructive feedback. + + Looks for a line matching 'verdict: ' (case-insensitive). + Everything after a 'feedback:' line is treated as the constructive feedback. + + Defaults to 'tests_incomplete' if verdict is absent or unrecognized. + """ + verdict = "tests_incomplete" + feedback = "" + + verdict_match = re.search(r"verdict:\s*`?([a-zA-Z_]+)", output, re.IGNORECASE) + if verdict_match: + candidate = verdict_match.group(1).strip().lower() + if candidate in {"adequate", "tests_incomplete"}: + verdict = candidate + else: + logger.warning( + f"Unrecognized verdict string '{candidate}', defaulting to tests_incomplete" + ) + + feedback_match = re.search(r"feedback:\s*(.*)", output, re.IGNORECASE | re.DOTALL) + if feedback_match: + feedback = feedback_match.group(1).strip() + + return verdict, feedback + + +async def run_qualitative_review(state: WorkflowState) -> WorkflowState: + """Assess git diff against Jira ticket Acceptance Criteria using a read-only LLM reviewer. + + Args: + state: Current workflow state. + + Returns: + Updated workflow state with verdict, feedback, and retry metrics. 
+ """ + ticket_key = state["ticket_key"] + workspace_path = state.get("workspace_path") + current_repo = state.get("current_repo", "") + branch_name = state.get("context", {}).get("branch_name", "") + current_task = state.get("current_task_key") or ticket_key + + settings = get_settings() + jira = JiraClient(settings) + + if not workspace_path: + logger.error(f"No workspace for qualitative review on {ticket_key}") + return cast( + WorkflowState, + update_state_timestamp( + { + **state, + "last_error": "Workspace not set up", + "current_node": "qualitative_review", + } + ), + ) + + try: + # Fetch ticket details from Jira + task_issue = await jira.get_issue(current_task) + description = task_issue.description or "" + acceptance_criteria = _extract_acceptance_criteria(description) + + await post_status_comment( + jira, + ticket_key, + f"🔍 Forge is performing a qualitative review on the changes for {current_task}...", + ) + + # Initialize GitOperations to retrieve git diff + workspace_obj = Workspace( + path=Path(workspace_path), + repo_name=current_repo or "", + branch_name=branch_name or "", + ticket_key=ticket_key, + ) + git = GitOperations(workspace_obj) + git_diff = _get_git_diff(git) + + # Set up a read-only ForgeAgent (include_tools=False) + agent = ForgeAgent(settings) + + # Prepare the qualitative review prompt + from forge.prompts import load_prompt + + prompt_content = load_prompt( + "task-takeover-review", + acceptance_criteria=acceptance_criteria, + git_diff=git_diff, + ) + + # Run review via agent + response = await agent.run_task( + task="task-takeover-review", + prompt=prompt_content, + include_tools=False, + trace_context={ + "ticket_key": ticket_key, + "current_node": "qualitative_review", + }, + ) + + # Parse verdict and feedback + verdict, feedback = _parse_qualitative_review(response) + + # Update retry metrics + current_retry_count = state.get("qualitative_review_retry_count", 0) + new_retry_count = current_retry_count + (0 if verdict == 
"adequate" else 1) + failed = verdict != "adequate" + + await post_status_comment( + jira, + ticket_key, + f"📋 Qualitative review verdict: **{verdict}**\n\nFeedback:\n{feedback}", + ) + + return cast( + WorkflowState, + update_state_timestamp( + { + **state, + "review_verdict": verdict, + "review_feedback": feedback, + "qualitative_review_retry_count": new_retry_count, + "qualitative_review_failed": failed, + "current_node": "qualitative_review", + "last_error": None, + } + ), + ) + + except Exception as e: + logger.error(f"run_qualitative_review failed for {ticket_key}: {e}") + with contextlib.suppress(Exception): + from forge.workflow.nodes.error_handler import notify_error + + await notify_error(state, str(e), "qualitative_review") # type: ignore[arg-type] + + return cast( + WorkflowState, + update_state_timestamp( + { + **state, + "last_error": str(e), + "current_node": "qualitative_review", + } + ), + ) + finally: + await jira.close() diff --git a/src/forge/workflow/nodes/task_takeover_triage.py b/src/forge/workflow/nodes/task_takeover_triage.py new file mode 100644 index 00000000..56bdac81 --- /dev/null +++ b/src/forge/workflow/nodes/task_takeover_triage.py @@ -0,0 +1,169 @@ +"""Triage node for Task Takeover workflow. + +Evaluates whether a Task or Epic ticket contains sufficient details (Problem Statement, +Proposed Solution/Approach, and Acceptance Criteria) before starting plan generation. 
+""" + +import json +import logging +from typing import cast + +from forge.config import get_settings +from forge.integrations.agents import ForgeAgent +from forge.integrations.jira.client import JiraClient +from forge.models.workflow import ForgeLabel +from forge.prompts import load_prompt +from forge.workflow.task_takeover.state import TaskTakeoverState +from forge.workflow.utils import update_state_timestamp + +logger = logging.getLogger(__name__) + +_MAX_RETRIES = 3 + +__all__ = ["triage_task"] + + +async def triage_task(state: TaskTakeoverState) -> TaskTakeoverState: + """Evaluate a Task Takeover ticket for completeness before planning. + + Posts an acknowledgement comment on first execution, then evaluates the + ticket against "Problem Statement", "Proposed Solution/Approach", and + "Acceptance Criteria". + + If sufficient, transitions current_node to generate_plan and proceeds. + If missing sections, applies forge:task-triage-pending label, posts a + detailed public comment, sets is_paused = True, and routes to triage_gate. + + On resume, re-evaluates the updated ticket and proceeds to planning if now + sufficient. + + Args: + state: Current TaskTakeoverState. + + Returns: + Updated TaskTakeoverState. 
+ """ + ticket_key = state["ticket_key"] + retry_count = state.get("retry_count", 0) + is_resume = state.get("current_node") == "triage_gate" + + settings = get_settings() + jira = JiraClient(settings) + agent = ForgeAgent(settings) + + try: + if retry_count >= _MAX_RETRIES: + logger.error("triage_task exceeded max retries for %s", ticket_key) + return cast( + TaskTakeoverState, + { + **state, + "current_node": "escalate_blocked", + "is_paused": False, + }, + ) + + # Step 1: Post acknowledgement on first execution only (not on resume) + if not is_resume: + await jira.add_comment( + ticket_key, + "Received task/epic for Task Takeover — checking ticket completeness before starting planning.", + ) + + # Step 2: Fetch full ticket content + issue = await jira.get_issue(ticket_key) + comments = await jira.get_comments(ticket_key) + comment_text = "\n\n".join(c.body for c in comments if c.body) + + # Step 3: Invoke task takeover triage prompt + user_prompt = load_prompt( + "task-takeover-triage", + summary=issue.summary or "", + description=issue.description or "", + comments=comment_text, + ) + raw_result = await agent.run_task( + task="task-takeover-triage", + prompt=user_prompt, + context={"ticket_key": ticket_key}, + ) + + # Step 4: Parse result + result_stripped = raw_result.strip() + if result_stripped.lower() == "sufficient": + pass_msg = ( + "Thanks for the update — ticket now has enough information to proceed. " + "Starting plan generation — results will be posted here." + if is_resume + else "Ticket has enough information to proceed. Starting plan generation — results will be posted here." 
+ ) + await jira.add_comment(ticket_key, pass_msg) + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "triage_passed": True, + "triage_missing_fields": [], + "current_node": "generate_plan", + "is_paused": False, + "last_error": None, + "retry_count": 0, + } + ), + ) + + # Step 5: Missing fields path + # Strip markdown code fences that LLMs sometimes add despite instructions + json_candidate = result_stripped + if json_candidate.startswith("```"): + lines = json_candidate.splitlines() + json_candidate = "\n".join(line for line in lines if not line.startswith("```")).strip() + try: + missing_fields = json.loads(json_candidate) + if not isinstance(missing_fields, list): + raise ValueError("Expected a list") + except (json.JSONDecodeError, ValueError): + logger.warning("Unexpected triage output for %s: %r", ticket_key, result_stripped) + missing_fields = [ + "(could not determine — please provide additional context about the task)" + ] + + fields_listed = "\n".join(f"- {f}" for f in missing_fields) + await jira.add_comment( + ticket_key, + f"To proceed with task takeover planning, please provide the following information:\n\n{fields_listed}", + ) + await jira.set_workflow_label(ticket_key, ForgeLabel.TASK_TRIAGE_PENDING) + + return cast( + TaskTakeoverState, + update_state_timestamp( + { + **state, + "triage_passed": False, + "triage_missing_fields": missing_fields, + "current_node": "triage_gate", + "is_paused": True, + "last_error": None, + "retry_count": 0, + } + ), + ) + + except Exception as e: + logger.error("triage_task failed for %s: %s", ticket_key, e) + new_retry = retry_count + 1 + return cast( + TaskTakeoverState, + { + **state, + "last_error": str(e), + "retry_count": new_retry, + "current_node": "escalate_blocked" if new_retry >= _MAX_RETRIES else "triage_check", + "is_paused": False, + }, + ) + finally: + await jira.close() + await agent.close() diff --git a/src/forge/workflow/registry.py b/src/forge/workflow/registry.py 
index 7e34df78..47db96cc 100644
--- a/src/forge/workflow/registry.py
+++ b/src/forge/workflow/registry.py
@@ -3,11 +3,13 @@
 from forge.workflow.bug import BugWorkflow
 from forge.workflow.feature import FeatureWorkflow
 from forge.workflow.router import WorkflowRouter
+from forge.workflow.task_takeover import TaskTakeoverWorkflow
 
 
 def create_default_router() -> WorkflowRouter:
     """Create router with built-in workflows."""
     router = WorkflowRouter()
+    router.register(TaskTakeoverWorkflow)
     router.register(FeatureWorkflow)
     router.register(BugWorkflow)
     return router
diff --git a/src/forge/workflow/router.py b/src/forge/workflow/router.py
index b68387a8..8435297a 100644
--- a/src/forge/workflow/router.py
+++ b/src/forge/workflow/router.py
@@ -24,6 +24,42 @@ def resolve(
     ) -> BaseWorkflow | None:
         """Find the first matching workflow for given ticket/event."""
         for workflow_class in self._workflows:
+            if workflow_class.name == "task_takeover":
+                # Guarantee exact label matching for resolving triggers, avoiding any prefix-based triggers
+                allowed_triggers = {
+                    "forge:task-takeover",
+                    "forge:managed:task",
+                    "forge:managed:task-takeover",
+                }
+                try:
+                    from forge.config import get_settings
+
+                    settings = get_settings()
+                    if (
+                        settings.task_takeover
+                        and settings.task_takeover.labels
+                        and settings.task_takeover.labels.trigger
+                    ):
+                        allowed_triggers.add(settings.task_takeover.labels.trigger)
+                except Exception:
+                    pass
+
+                # Filter out labels that start with trigger prefixes but are not exact matches
+                cleaned_labels = []
+                for label in labels:
+                    is_prefix_trigger = False
+                    for trigger_prefix in ["forge:task-takeover", "forge:managed:task"]:
+                        if label.startswith(trigger_prefix) and label not in allowed_triggers:
+                            is_prefix_trigger = True
+                            break
+                    if not is_prefix_trigger:
+                        cleaned_labels.append(label)
+
+                instance = workflow_class()
+                if instance.matches(ticket_type, cleaned_labels, event):
+                    return instance
+                continue
+
             instance = workflow_class()
             if instance.matches(ticket_type, labels, event):
                 return instance
diff --git a/src/forge/workflow/task_takeover/__init__.py b/src/forge/workflow/task_takeover/__init__.py
new file mode 100644
index 00000000..a66e1377
--- /dev/null
+++ b/src/forge/workflow/task_takeover/__init__.py
@@ -0,0 +1,65 @@
+"""Task Takeover workflow implementation."""
+
+from typing import Any, cast
+
+from langgraph.graph import StateGraph
+
+from forge.models.workflow import TicketType
+from forge.workflow.base import BaseWorkflow
+from forge.workflow.task_takeover.state import (
+    TaskTakeoverState,
+    create_initial_task_takeover_state,
+)
+
+
+class TaskTakeoverWorkflow(BaseWorkflow):
+    """Workflow for Task Takeover tickets."""
+
+    name = "task_takeover"
+    description = "Task Takeover workflow"
+
+    @property
+    def state_schema(self) -> type:
+        return TaskTakeoverState
+
+    def matches(self, _ticket_type: TicketType, labels: list[str], _event: dict[str, Any]) -> bool:
+        """Return True only if task_takeover is enabled and any exact task-takeover trigger is present."""
+        try:
+            from forge.config import get_settings
+
+            settings = get_settings()
+            if not settings.task_takeover or not settings.task_takeover.enabled:
+                return False
+        except Exception:
+            return False
+
+        # Define the exact trigger labels
+        trigger_labels = {
+            "forge:task-takeover",
+            "forge:managed:task",
+            "forge:managed:task-takeover",
+        }
+
+        # Include custom trigger from settings if available
+        if (
+            settings.task_takeover
+            and settings.task_takeover.labels
+            and settings.task_takeover.labels.trigger
+        ):
+            trigger_labels.add(settings.task_takeover.labels.trigger)
+
+        # Check if any exact trigger label is present in the labels list
+        return any(label in labels for label in trigger_labels)
+
+    def build_graph(self) -> StateGraph[Any]:
+        """Construct the LangGraph StateGraph for Task Takeover."""
+        from forge.workflow.task_takeover.graph import build_task_takeover_graph
+
+        return build_task_takeover_graph()
+
+    def create_initial_state(self, ticket_key: str, **kwargs: Any) -> dict[str, Any]:
+        """Create initial state for a new Task Takeover workflow run."""
+        return cast(dict[str, Any], create_initial_task_takeover_state(ticket_key, **kwargs))
+
+
+__all__ = ["TaskTakeoverWorkflow", "TaskTakeoverState", "create_initial_task_takeover_state"]
diff --git a/src/forge/workflow/task_takeover/graph.py b/src/forge/workflow/task_takeover/graph.py
new file mode 100644
index 00000000..1a95b0d6
--- /dev/null
+++ b/src/forge/workflow/task_takeover/graph.py
@@ -0,0 +1,242 @@
+"""Task Takeover workflow graph construction.
+
+This module builds the LangGraph StateGraph for the Task Takeover workflow.
+"""
+
+import logging
+from typing import Any
+
+from langgraph.graph import END, StateGraph
+
+from forge.workflow.gates.task_plan_approval import (
+    route_task_plan_approval,
+    task_plan_approval_gate,
+)
+from forge.workflow.nodes import (
+    answer_question,
+    create_task_takeover_pr,
+    escalate_to_blocked,
+    execute_task_changes,
+    generate_plan,
+    route_triage_gate,
+    run_qualitative_review,
+    setup_workspace,
+    triage_gate,
+    triage_task,
+)
+from forge.workflow.task_takeover.state import TaskTakeoverState
+from forge.workflow.utils import resolve_shared_resume_node
+
+logger = logging.getLogger(__name__)
+
+
+def route_entry(state: TaskTakeoverState) -> str:
+    """Route workflow based on current progress for resume/retry.
+
+    New tickets start at triage_check. In-flight tickets with a saved current_node
+    resume at the appropriate point.
+
+    Args:
+        state: Current workflow state.
+
+    Returns:
+        Next node name based on current progress.
+    """
+    current_node = state.get("current_node", "")
+
+    if current_node and current_node not in ("entry", "route_entry", "__end__", "", "start"):
+        logger.info(f"Resuming task takeover workflow at node: {current_node}")
+
+        # Shared nodes: same resume mapping across all workflow types
+        shared = resolve_shared_resume_node(current_node)
+        if shared is not None:
+            if shared is END:
+                logger.info(f"Workflow at terminal state '{current_node}', returning END")
+            return shared
+
+        # Task takeover-specific resume mapping
+        if current_node == "triage_check":
+            return "triage_check"
+        elif current_node == "triage_gate":
+            return "triage_gate"
+        elif current_node == "generate_plan":
+            return "generate_plan"
+        elif current_node == "task_plan_approval_gate":
+            return "task_plan_approval_gate"
+        elif current_node == "setup_workspace":
+            return "setup_workspace"
+        elif current_node == "execute_task_changes":
+            return "execute_task_changes"
+        elif current_node == "qualitative_review":
+            return "run_qualitative_review"
+        elif current_node == "create_task_takeover_pr":
+            return "create_task_takeover_pr"
+        elif current_node == "escalate_blocked":
+            return "escalate_blocked"
+        else:
+            logger.warning(f"Unrecognized current_node '{current_node}', restarting from triage")
+
+    # New tasks start at triage
+    return "triage_check"
+
+
+def _route_after_triage_check(state: TaskTakeoverState) -> str:
+    """Route after triage_check based on what triage_check set as current_node."""
+    node = state.get("current_node", "triage_gate")
+    if node in ("analyze_bug", "generate_plan"):
+        return "generate_plan"
+    if node in ("triage_gate", "escalate_blocked"):
+        return node
+    return "triage_gate"
+
+
+def _route_after_answer(state: TaskTakeoverState) -> str:
+    """Route back to the original gate after answering a question.
+
+    The answer_question node preserves current_node as the gate to return to.
+    """
+    current_node = state.get("current_node", "")
+    if current_node and "gate" in current_node:
+        return current_node
+    return "task_plan_approval_gate"
+
+
+def _route_after_qualitative_review(state: TaskTakeoverState) -> str:
+    """Route after run_qualitative_review considering qualitative verdict and retry count.
+
+    If the review is adequate (success), proceed to create_task_takeover_pr.
+    If the review is failed or incomplete:
+    - Check if we've reached the configured retry limit.
+    - If limit reached: transition to escalate_blocked.
+    - Otherwise: transition back to execute_task_changes.
+    """
+    verdict = state.get("review_verdict")
+    retry_count = state.get("qualitative_review_retry_count", 0)
+
+    if verdict == "adequate":
+        return "create_task_takeover_pr"
+
+    # Fetch configured retry limit (review_max_attempts) from settings, default to 2
+    try:
+        from forge.config import get_settings
+
+        settings = get_settings()
+        limit = settings.task_takeover.review_max_attempts
+    except Exception:
+        limit = 2
+
+    if retry_count >= limit:
+        logger.warning(
+            f"Qualitative review cap ({limit}) reached on task takeover workflow, transitioning to escalate_blocked"
+        )
+        return "escalate_blocked"
+
+    logger.info(
+        f"Qualitative review verdict is {verdict!r}, retry attempt {retry_count}/{limit}, "
+        "routing back to execute_task_changes"
+    )
+    return "execute_task_changes"
+
+
+def build_task_takeover_graph() -> StateGraph[TaskTakeoverState, Any, Any]:
+    """Create the Task Takeover workflow graph.
+
+    Returns:
+        Configured StateGraph ready for compilation.
+    """
+    graph = StateGraph(TaskTakeoverState)
+
+    # Entry routing
+    graph.add_node("route_entry", lambda state: state)
+
+    # Nodes
+    graph.add_node("triage_check", triage_task)
+    graph.add_node("triage_gate", triage_gate)
+    graph.add_node("generate_plan", generate_plan)
+    graph.add_node("task_plan_approval_gate", task_plan_approval_gate)
+    graph.add_node("escalate_blocked", escalate_to_blocked)
+    graph.add_node("answer_question", answer_question)
+    graph.add_node("setup_workspace", setup_workspace)
+    graph.add_node("execute_task_changes", execute_task_changes)
+    graph.add_node("run_qualitative_review", run_qualitative_review)
+    graph.add_node("create_task_takeover_pr", create_task_takeover_pr)
+
+    # Set entry point
+    graph.set_entry_point("route_entry")
+
+    # Entry routing edges
+    graph.add_conditional_edges(
+        "route_entry",
+        route_entry,
+        {
+            "triage_check": "triage_check",
+            "triage_gate": "triage_gate",
+            "generate_plan": "generate_plan",
+            "task_plan_approval_gate": "task_plan_approval_gate",
+            "setup_workspace": "setup_workspace",
+            "execute_task_changes": "execute_task_changes",
+            "run_qualitative_review": "run_qualitative_review",
+            "create_task_takeover_pr": "create_task_takeover_pr",
+            "escalate_blocked": "escalate_blocked",
+            END: END,
+        },
+    )
+
+    # Triage flow
+    graph.add_conditional_edges(
+        "triage_check",
+        _route_after_triage_check,
+        {
+            "triage_gate": "triage_gate",
+            "generate_plan": "generate_plan",
+            "escalate_blocked": "escalate_blocked",
+        },
+    )
+    graph.add_conditional_edges(
+        "triage_gate",
+        route_triage_gate,
+        {
+            END: END,
+            "triage_check": "triage_check",
+        },
+    )
+
+    # Planning flow
+    graph.add_edge("generate_plan", "task_plan_approval_gate")
+    graph.add_conditional_edges(
+        "task_plan_approval_gate",
+        route_task_plan_approval,
+        {
+            "regenerate_plan": "generate_plan",
+            "answer_question": "answer_question",
+            "setup_workspace": "setup_workspace",
+            END: END,
+        },
+    )
+
+    # Execution flow
+    graph.add_edge("setup_workspace", "execute_task_changes")
+    graph.add_edge("execute_task_changes", "run_qualitative_review")
+    graph.add_conditional_edges(
+        "run_qualitative_review",
+        _route_after_qualitative_review,
+        {
+            "execute_task_changes": "execute_task_changes",
+            "create_task_takeover_pr": "create_task_takeover_pr",
+            "escalate_blocked": "escalate_blocked",
+        },
+    )
+    graph.add_edge("create_task_takeover_pr", END)
+
+    # Q&A routing
+    graph.add_conditional_edges(
+        "answer_question",
+        _route_after_answer,
+        {
+            "task_plan_approval_gate": "task_plan_approval_gate",
+        },
+    )
+
+    graph.add_edge("escalate_blocked", END)
+
+    return graph
diff --git a/src/forge/workflow/task_takeover/state.py b/src/forge/workflow/task_takeover/state.py
new file mode 100644
index 00000000..252d17fd
--- /dev/null
+++ b/src/forge/workflow/task_takeover/state.py
@@ -0,0 +1,67 @@
+"""Task Takeover workflow state definition."""
+
+from datetime import datetime
+from typing import Any, cast
+
+from forge.models.workflow import TicketType
+from forge.workflow.base import (
+    BaseState,
+    CIIntegrationState,
+    PRIntegrationState,
+    ReviewIntegrationState,
+)
+
+
+class TaskTakeoverState(
+    BaseState, PRIntegrationState, CIIntegrationState, ReviewIntegrationState, total=False
+):
+    """State specific to Task Takeover workflow."""
+
+    ticket_type: TicketType
+    plan_content: str | None
+    triage_passed: bool
+    triage_missing_fields: list[str]
+    review_verdict: str | None
+    review_feedback: str | None
+    qualitative_review_retry_count: int
+    qualitative_review_failed: bool
+
+
+def create_initial_task_takeover_state(ticket_key: str, **kwargs: Any) -> TaskTakeoverState:
+    """Create initial state for a new Task Takeover workflow run."""
+    now = datetime.utcnow().isoformat()
+    defaults: dict[str, Any] = {
+        "thread_id": ticket_key,
+        "ticket_key": ticket_key,
+        "ticket_type": TicketType.TASK,
+        "current_node": "start",
+        "is_paused": False,
+        "retry_count": 0,
+        "last_error": None,
+        "created_at": now,
+        "updated_at": now,
"workspace_path": None, + "pr_urls": [], + "fork_owner": None, + "fork_repo": None, + "merge_conflicts": [], + "local_review_attempts": 0, + "local_review_pass_number": 1, + "ci_status": None, + "current_pr_url": None, + "current_pr_number": None, + "current_repo": None, + "repos_to_process": [], + "repos_completed": [], + "implemented_tasks": [], + "current_task_key": None, + "triage_passed": False, + "triage_missing_fields": [], + "plan_content": None, + "review_verdict": None, + "review_feedback": None, + "qualitative_review_retry_count": 0, + "qualitative_review_failed": False, + } + defaults.update(kwargs) + return cast(TaskTakeoverState, defaults) diff --git a/tests/flows/status_transitions/test_prd_rejected.py b/tests/flows/status_transitions/test_prd_rejected.py index e0a356ea..88bcbf90 100644 --- a/tests/flows/status_transitions/test_prd_rejected.py +++ b/tests/flows/status_transitions/test_prd_rejected.py @@ -5,9 +5,9 @@ import pytest from forge.models.workflow import TicketType +from forge.workflow.feature.state import create_initial_feature_state as create_initial_state from forge.workflow.gates import route_prd_approval from forge.workflow.nodes import regenerate_prd_with_feedback -from forge.workflow.feature.state import create_initial_feature_state as create_initial_state class TestPrdRejectedOnce: @@ -54,6 +54,7 @@ async def test_regeneration_incorporates_feedback(self, prd_pending_state): mock_jira = MagicMock() mock_jira.update_description = AsyncMock() mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() mock_jira.close = AsyncMock() mock_agent = MagicMock() @@ -94,6 +95,7 @@ async def test_after_regeneration_returns_to_pending(self, prd_pending_state): mock_jira = MagicMock() mock_jira.update_description = AsyncMock() mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() mock_jira.close = AsyncMock() mock_agent = MagicMock() @@ -159,13 +161,12 @@ async def 
test_revision_count_increments(self, prd_state_first_revision): mock_jira = MagicMock() mock_jira.update_description = AsyncMock() mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() mock_jira.close = AsyncMock() mock_agent = MagicMock() # Simulate error to increment retry count - mock_agent.regenerate_with_feedback = AsyncMock( - side_effect=Exception("Simulated error") - ) + mock_agent.regenerate_with_feedback = AsyncMock(side_effect=Exception("Simulated error")) mock_agent.close = AsyncMock() with patch("forge.workflow.nodes.prd_generation.JiraClient", return_value=mock_jira): @@ -202,6 +203,7 @@ async def test_regeneration_uses_original_prd(self, prd_with_context): mock_jira = MagicMock() mock_jira.update_description = AsyncMock() mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() mock_jira.close = AsyncMock() mock_agent = MagicMock() @@ -222,6 +224,7 @@ async def test_feedback_is_passed_to_agent(self, prd_with_context): mock_jira = MagicMock() mock_jira.update_description = AsyncMock() mock_jira.add_comment = AsyncMock() + mock_jira.add_structured_comment = AsyncMock() mock_jira.close = AsyncMock() mock_agent = MagicMock() diff --git a/tests/integration/orchestrator/test_local_review_status_comments.py b/tests/integration/orchestrator/test_local_review_status_comments.py index f7da13b8..9f8f953f 100644 --- a/tests/integration/orchestrator/test_local_review_status_comments.py +++ b/tests/integration/orchestrator/test_local_review_status_comments.py @@ -129,67 +129,34 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass1), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass1, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await 
local_review_changes(state) - # Pass 2: has unfixed issues, should post fix comment with pass 2 and retry - mock_runner_pass2 = create_mock_container_runner(has_unfixed_issues=True) - - with ( - patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass2), - patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), - ): - state = await local_review_changes(state) - - # Pass 3: no unfixed issues, should post fix comment with pass 3 and route to create_pr - # Note: MAX_REVIEW_ATTEMPTS is 2, so pass 3 would be the final attempt - # We need to test the scenario where it succeeds on the last attempt - mock_runner_pass3 = create_mock_container_runner(has_unfixed_issues=False) + # Pass 2: no unfixed issues, should post fix comment with pass 2 and route to create_pr + mock_runner_pass2 = create_mock_container_runner(has_unfixed_issues=False) with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass3), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass2, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): - result = await local_review_changes(state) + await local_review_changes(state) - # Verify all comments were posted: initial + fix(2) + fix(3) - # Note: Only 2 comments will be posted because MAX_REVIEW_ATTEMPTS=2 - # Pass 1: initial comment, Pass 2: fix comment (pass 2) - # Pass 3 would exceed max attempts, so it doesn't run the container - # Let me reconsider the test scenario based on MAX_REVIEW_ATTEMPTS=2 - - # With MAX_REVIEW_ATTEMPTS=2: - # Pass 1 (attempt 0): initial comment, finds issues, increments to attempt 1, pass 2 - # Pass 2 (attempt 1): fix comment (pass 2), finds no issues OR hits max attempts - - # For a 
3-comment scenario (initial + 2 fix comments), we need: - # Pass 1: initial, finds issues -> retry - # Pass 2: fix (pass 2), finds issues -> retry - # Pass 3: Would be attempt 2 which equals MAX_REVIEW_ATTEMPTS, so it runs one more time - - # Actually reviewing the code: review_attempts + 1 < MAX_REVIEW_ATTEMPTS - # So with MAX_REVIEW_ATTEMPTS=2: - # - attempt 0: runs, if issues and 0+1 < 2, retry (yes) - # - attempt 1: runs, if issues and 1+1 < 2, retry (no, 2 is not < 2) - - # So we can only get 2 passes max with MAX_REVIEW_ATTEMPTS=2 - # Pass 1 (attempt 0): initial comment - # Pass 2 (attempt 1): fix comment (pass 2) - - # For TS-005 to work as specified (3 fix passes), I need to adjust the test - # or acknowledge that MAX_REVIEW_ATTEMPTS limits this - - # Let me verify what comments were actually posted + # Verify all comments were posted: initial + fix(pass 2) assert len(all_comments) == 2 # Initial + fix(pass 2) - + # Verify initial comment assert all_comments[0][0] == "FEAT-201" assert all_comments[0][1] == "🔍 Running local code review on changes before creating PR." - + # Verify fix comment with pass 2 assert all_comments[1][0] == "FEAT-201" assert all_comments[1][1] == "🔧 Local review found issues, applying fixes (pass 2)." 
@@ -225,7 +192,10 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass1), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass1, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -235,7 +205,10 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass2), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass2, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -245,22 +218,25 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass3), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass3, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): result = await local_review_changes(state) # Verify all comments were posted: initial + fix(2) + fix(3) assert len(all_comments) == 3 - + # Verify initial comment assert all_comments[0][0] == "FEAT-202" assert all_comments[0][1] == "🔍 Running local code review on changes before creating PR." - + # Verify fix comment with pass 2 assert all_comments[1][0] == "FEAT-202" assert all_comments[1][1] == "🔧 Local review found issues, applying fixes (pass 2)." 
- + # Verify fix comment with pass 3 assert all_comments[2][0] == "FEAT-202" assert all_comments[2][1] == "🔧 Local review found issues, applying fixes (pass 3)." @@ -307,23 +283,31 @@ def track_comment(ticket_key, message): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner), - patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner, + ), + patch( + "forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git + ), ): state = await local_review_changes(state) # Verify all comments were posted: initial + fix(2) + fix(3) + fix(4) + fix(5) + fix(6) assert len(all_comments) == 6 - + # Verify initial comment assert all_comments[0][0] == "FEAT-203" assert all_comments[0][1] == "🔍 Running local code review on changes before creating PR." - + # Verify fix comments with incrementing pass numbers for i in range(1, 6): pass_num = i + 1 assert all_comments[i][0] == "FEAT-203" - assert all_comments[i][1] == f"🔧 Local review found issues, applying fixes (pass {pass_num})." + assert ( + all_comments[i][1] + == f"🔧 Local review found issues, applying fixes (pass {pass_num})." 
+ ) # Verify workflow routed to create_pr assert state["current_node"] == "create_pr" @@ -363,7 +347,7 @@ async def test_pass_number_resets_when_transitioning_from_implementation_to_loca ): mock_git = create_mock_git_operations(has_changes=False) mock_git_class.return_value = mock_git - + result = await implement_task(state) # Verify pass_number was reset to 1 when entering local_review phase @@ -405,7 +389,10 @@ async def test_pass_number_persists_and_increments_within_same_feature(self): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass1), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass1, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): state = await local_review_changes(state) @@ -420,7 +407,10 @@ async def test_pass_number_persists_and_increments_within_same_feature(self): with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner_pass2), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner_pass2, + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): result = await local_review_changes(state) @@ -448,13 +438,18 @@ async def test_pass_number_increments_correctly_across_multiple_iterations(self) # Passes 1-3: have unfixed issues for expected_pass_num in [1, 2, 3]: assert state["local_review_pass_number"] == expected_pass_num - + mock_runner = create_mock_container_runner(has_unfixed_issues=True) with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner), - patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), 
+ patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", + return_value=mock_runner, + ), + patch( + "forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git + ), ): state = await local_review_changes(state) @@ -468,7 +463,9 @@ async def test_pass_number_increments_correctly_across_multiple_iterations(self) with ( patch("forge.workflow.nodes.local_reviewer.JiraClient", return_value=mock_jira), - patch("forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner), + patch( + "forge.workflow.nodes.local_reviewer.ContainerRunner", return_value=mock_runner + ), patch("forge.workflow.nodes.local_reviewer.GitOperations", return_value=mock_git), ): result = await local_review_changes(state) diff --git a/tests/integration/orchestrator/test_task_handoff.py b/tests/integration/orchestrator/test_task_handoff.py index c4c36ce1..fbf0e316 100644 --- a/tests/integration/orchestrator/test_task_handoff.py +++ b/tests/integration/orchestrator/test_task_handoff.py @@ -41,7 +41,7 @@ async def test_workspace_setup_creates_forge_directory(self): async def test_workspace_setup_node_creates_forge_directory(self): """The setup_workspace node should create .forge directory structure.""" - from forge.orchestrator.nodes import setup_workspace + from forge.workflow.nodes import setup_workspace initial_state = create_initial_state( thread_id="TEST-123", @@ -50,14 +50,17 @@ async def test_workspace_setup_node_creates_forge_directory(self): ) initial_state["tasks_by_repo"] = {"test-org/test-repo": ["TASK-1", "TASK-2"]} - with patch("forge.workflow.nodes.workspace_setup.GitOperations") as MockGit, \ - patch("forge.workflow.nodes.workspace_setup.GuardrailsLoader") as MockGuardrails: - + with ( + patch("forge.workflow.nodes.workspace_setup.GitOperations") as MockGit, + patch("forge.workflow.nodes.workspace_setup.GuardrailsLoader") as MockGuardrails, + ): mock_git = MagicMock() MockGit.return_value = mock_git mock_guardrails = MagicMock() - 
mock_guardrails.load.return_value = MagicMock(get_system_context=MagicMock(return_value="")) + mock_guardrails.load.return_value = MagicMock( + get_system_context=MagicMock(return_value="") + ) MockGuardrails.return_value = mock_guardrails result = await setup_workspace(initial_state) @@ -66,7 +69,9 @@ async def test_workspace_setup_node_creates_forge_directory(self): if result.get("workspace_path"): workspace_path = Path(result["workspace_path"]) assert (workspace_path / ".forge").exists(), ".forge should be created" - assert (workspace_path / ".forge" / "history").exists(), ".forge/history should be created" + assert (workspace_path / ".forge" / "history").exists(), ( + ".forge/history should be created" + ) class TestPreviousTaskKeysPassing: @@ -80,9 +85,10 @@ async def test_runner_passes_previous_task_keys_in_task_file(self): workspace = Path(workspace_dir) # Mock podman and settings - with patch("forge.sandbox.runner.shutil.which", return_value="/usr/bin/podman"), \ - patch("forge.sandbox.runner.get_settings") as mock_settings: - + with ( + patch("forge.sandbox.runner.shutil.which", return_value="/usr/bin/podman"), + patch("forge.sandbox.runner.get_settings") as mock_settings, + ): settings = MagicMock() settings.anthropic_api_key.get_secret_value.return_value = "test-key" settings.use_vertex_ai = False @@ -96,9 +102,10 @@ async def test_runner_passes_previous_task_keys_in_task_file(self): runner = ContainerRunner(settings) # Mock the actual run to just create the task file - with patch.object(runner, "_build_podman_command", return_value=["echo", "test"]), \ - patch("asyncio.create_subprocess_exec") as mock_exec: - + with ( + patch.object(runner, "_build_podman_command", return_value=["echo", "test"]), + patch("asyncio.create_subprocess_exec") as mock_exec, + ): mock_process = AsyncMock() mock_process.communicate = AsyncMock(return_value=(b"", b"")) mock_process.returncode = 0 @@ -118,8 +125,8 @@ async def 
test_runner_passes_previous_task_keys_in_task_file(self): async def test_implementation_node_passes_implemented_tasks(self): """Implementation node should pass implemented_tasks as previous_task_keys.""" - from forge.orchestrator.nodes import implement_task from forge.workflow.feature.state import FeatureState as WorkflowState + from forge.workflow.nodes import implement_task with tempfile.TemporaryDirectory() as workspace_dir: state: WorkflowState = { @@ -133,10 +140,11 @@ async def test_implementation_node_passes_implemented_tasks(self): "context": {"guardrails": ""}, } - with patch("forge.workflow.nodes.implementation.JiraClient") as MockJira, \ - patch("forge.workflow.nodes.implementation.ContainerRunner") as MockRunner, \ - patch("forge.workflow.nodes.implementation.get_settings") as mock_settings: - + with ( + patch("forge.workflow.nodes.implementation.JiraClient") as MockJira, + patch("forge.workflow.nodes.implementation.ContainerRunner") as MockRunner, + patch("forge.workflow.nodes.implementation.get_settings") as mock_settings, + ): # Setup mocks mock_jira = MagicMock() mock_jira.get_issue = AsyncMock( @@ -149,9 +157,7 @@ async def test_implementation_node_passes_implemented_tasks(self): MockJira.return_value = mock_jira mock_runner = MagicMock() - mock_runner.run = AsyncMock( - return_value=MagicMock(success=True, exit_code=0) - ) + mock_runner.run = AsyncMock(return_value=MagicMock(success=True, exit_code=0)) MockRunner.return_value = mock_runner mock_settings.return_value = MagicMock() @@ -178,8 +184,9 @@ def test_container_system_prompt_includes_handoff_instructions(self): assert ".forge/history/" in prompt, "Prompt should reference history directory" # Check for handoff writing instructions - assert "Update handoff" in prompt or "update `.forge/handoff.md`" in prompt, \ + assert "Update handoff" in prompt or "update `.forge/handoff.md`" in prompt, ( "Prompt should instruct agent to update handoff" + ) def 
test_entrypoint_builds_prompt_with_previous_task_keys(self): """Entrypoint build_system_prompt should include previous task keys.""" @@ -228,8 +235,9 @@ def test_entrypoint_handles_empty_previous_tasks(self): ) # Should indicate this is the first task - assert "first task" in prompt.lower() or "none" in prompt.lower(), \ + assert "first task" in prompt.lower() or "none" in prompt.lower(), ( "Prompt should indicate no previous tasks" + ) finally: sys.path.remove(str(containers_path)) @@ -301,8 +309,9 @@ def test_container_prompt_includes_gitignore_instructions(self): # Prompt should warn against committing .forge/ (using "NEVER commit" wording) assert ".forge/" in prompt, "Prompt should mention .forge/ directory" - assert "NEVER commit" in prompt or "never commit" in prompt.lower(), \ + assert "NEVER commit" in prompt or "never commit" in prompt.lower(), ( "Prompt should warn against committing .forge/" + ) class TestHistoryPersistence: diff --git a/tests/integration/orchestrator/test_task_implementation_status.py b/tests/integration/orchestrator/test_task_implementation_status.py index 76060b86..b1e7de9a 100644 --- a/tests/integration/orchestrator/test_task_implementation_status.py +++ b/tests/integration/orchestrator/test_task_implementation_status.py @@ -76,7 +76,9 @@ async def test_single_task_receives_start_comment(self): assert mock_jira.add_comment.call_count >= 1 start_call = mock_jira.add_comment.call_args_list[0] assert start_call[0][0] == "TASK-001" - assert start_call[0][1] == "🔨 Forge is implementing this task." 
+ assert ( + start_call[0][1] == "🔨 Forge started implementing [TASK-001]: Task summary for testing" + ) @pytest.mark.asyncio async def test_single_task_receives_completion_comment_on_success(self): @@ -105,12 +107,17 @@ async def test_single_task_receives_completion_comment_on_success(self): # Verify start comment start_call = mock_jira.add_comment.call_args_list[0] assert start_call[0][0] == "TASK-001" - assert start_call[0][1] == "🔨 Forge is implementing this task." + assert ( + start_call[0][1] == "🔨 Forge started implementing [TASK-001]: Task summary for testing" + ) # Verify completion comment with exact text completion_call = mock_jira.add_comment.call_args_list[1] assert completion_call[0][0] == "TASK-001" - assert completion_call[0][1] == "✅ Implementation complete. Running local code review before PR." + assert ( + completion_call[0][1] + == "✅ Implementation complete. Running local code review before PR." + ) # Verify task was marked as implemented assert "TASK-001" in result["implemented_tasks"] @@ -119,7 +126,9 @@ async def test_single_task_receives_completion_comment_on_success(self): async def test_single_task_no_completion_comment_on_failure(self): """TS-003: Verify NO completion comment when task implementation fails.""" mock_jira = create_mock_jira_client() - mock_runner = create_mock_container_runner(success=False, error_message="Implementation error") + mock_runner = create_mock_container_runner( + success=False, error_message="Implementation error" + ) state = create_initial_feature_state( ticket_key="FEAT-100", @@ -141,7 +150,9 @@ async def test_single_task_no_completion_comment_on_failure(self): assert mock_jira.add_comment.call_count == 1 start_call = mock_jira.add_comment.call_args_list[0] assert start_call[0][0] == "TASK-001" - assert start_call[0][1] == "🔨 Forge is implementing this task." 
+ assert ( + start_call[0][1] == "🔨 Forge started implementing [TASK-001]: Task summary for testing" + ) # Verify error state assert result["last_error"] == "Implementation error" @@ -176,7 +187,10 @@ async def test_multiple_tasks_receive_independent_start_comments(self): # Verify first task got start and completion comments with correct task_key assert mock_jira1.add_comment.call_count == 2 assert mock_jira1.add_comment.call_args_list[0][0][0] == "TASK-100" - assert mock_jira1.add_comment.call_args_list[0][0][1] == "🔨 Forge is implementing this task." + assert ( + mock_jira1.add_comment.call_args_list[0][0][1] + == "🔨 Forge started implementing [TASK-100]: Task summary for testing" + ) assert mock_jira1.add_comment.call_args_list[1][0][0] == "TASK-100" # Reset mock for second task @@ -191,12 +205,15 @@ async def test_multiple_tasks_receive_independent_start_comments(self): patch("forge.workflow.nodes.implementation.JiraClient", return_value=mock_jira2), patch("forge.workflow.nodes.implementation.ContainerRunner", return_value=mock_runner2), ): - result2 = await implement_task(state2) + await implement_task(state2) # Verify second task got its own independent start and completion comments assert mock_jira2.add_comment.call_count == 2 assert mock_jira2.add_comment.call_args_list[0][0][0] == "TASK-101" - assert mock_jira2.add_comment.call_args_list[0][0][1] == "🔨 Forge is implementing this task." + assert ( + mock_jira2.add_comment.call_args_list[0][0][1] + == "🔨 Forge started implementing [TASK-101]: Task summary for testing" + ) assert mock_jira2.add_comment.call_args_list[1][0][0] == "TASK-101" @pytest.mark.asyncio @@ -226,8 +243,14 @@ async def test_multiple_tasks_receive_independent_completion_comments(self): call for call in mock_jira1.add_comment.call_args_list if call[0][0] == "TASK-200" ] assert len(task200_calls) == 2 - assert task200_calls[0][0][1] == "🔨 Forge is implementing this task." - assert task200_calls[1][0][1] == "✅ Implementation complete. 
Running local code review before PR." + assert ( + task200_calls[0][0][1] + == "🔨 Forge started implementing [TASK-200]: Task summary for testing" + ) + assert ( + task200_calls[1][0][1] + == "✅ Implementation complete. Running local code review before PR." + ) # Second task mock_jira2 = create_mock_jira_client() @@ -247,8 +270,14 @@ async def test_multiple_tasks_receive_independent_completion_comments(self): call for call in mock_jira2.add_comment.call_args_list if call[0][0] == "TASK-201" ] assert len(task201_calls) == 2 - assert task201_calls[0][0][1] == "🔨 Forge is implementing this task." - assert task201_calls[1][0][1] == "✅ Implementation complete. Running local code review before PR." + assert ( + task201_calls[0][0][1] + == "🔨 Forge started implementing [TASK-201]: Task summary for testing" + ) + assert ( + task201_calls[1][0][1] + == "✅ Implementation complete. Running local code review before PR." + ) # Third task mock_jira3 = create_mock_jira_client() @@ -268,8 +297,14 @@ async def test_multiple_tasks_receive_independent_completion_comments(self): call for call in mock_jira3.add_comment.call_args_list if call[0][0] == "TASK-202" ] assert len(task202_calls) == 2 - assert task202_calls[0][0][1] == "🔨 Forge is implementing this task." - assert task202_calls[1][0][1] == "✅ Implementation complete. Running local code review before PR." + assert ( + task202_calls[0][0][1] + == "🔨 Forge started implementing [TASK-202]: Task summary for testing" + ) + assert ( + task202_calls[1][0][1] + == "✅ Implementation complete. Running local code review before PR." 
+ ) # Verify all three tasks are marked as implemented assert result3["implemented_tasks"] == ["TASK-200", "TASK-201", "TASK-202"] @@ -304,7 +339,10 @@ async def test_task_implementation_fails_midway_no_completion_comment(self): # Verify only start comment, no completion comment assert mock_jira.add_comment.call_count == 1 assert mock_jira.add_comment.call_args_list[0][0][0] == "TASK-300" - assert mock_jira.add_comment.call_args_list[0][0][1] == "🔨 Forge is implementing this task." + assert ( + mock_jira.add_comment.call_args_list[0][0][1] + == "🔨 Forge started implementing [TASK-300]: Task summary for testing" + ) # Verify error is set and task not implemented assert "Container crashed" in result["last_error"] @@ -388,7 +426,8 @@ async def test_workflow_continues_when_start_comment_posting_fails(self, caplog) # Verify error was logged (from jira_status utility) assert any( - "Failed to post status comment to TASK-500" in record.message for record in caplog.records + "Failed to post status comment to TASK-500" in record.message + for record in caplog.records ) @pytest.mark.asyncio @@ -430,7 +469,8 @@ async def add_comment_side_effect(*args, **kwargs): # Verify error was logged assert any( - "Failed to post status comment to TASK-501" in record.message for record in caplog.records + "Failed to post status comment to TASK-501" in record.message + for record in caplog.records ) @pytest.mark.asyncio @@ -462,6 +502,8 @@ async def test_workflow_continues_when_all_comment_posting_fails(self, caplog): # Verify errors were logged for both start and completion attempts error_logs = [ - record for record in caplog.records if "Failed to post status comment to TASK-502" in record.message + record + for record in caplog.records + if "Failed to post status comment to TASK-502" in record.message ] assert len(error_logs) == 2 # Both start and completion comments should have logged errors diff --git a/tests/integration/test_qa_mode.py b/tests/integration/test_qa_mode.py index 
e1e4c64f..ea49dacd 100644 --- a/tests/integration/test_qa_mode.py +++ b/tests/integration/test_qa_mode.py @@ -15,8 +15,8 @@ def test_question_comment_classified_correctly(self): """Verify comment classifier detects questions.""" assert classify_comment("?Why REST?") == CommentType.QUESTION assert classify_comment("@forge ask explain") == CommentType.QUESTION - assert classify_comment("Add more detail") == CommentType.FEEDBACK - assert classify_comment("LGTM") == CommentType.FEEDBACK + assert classify_comment("!Add more detail") == CommentType.FEEDBACK + assert classify_comment("LGTM") == CommentType.INFORMATIONAL def test_state_has_qa_fields(self): """Verify initial state includes Q&A fields.""" diff --git a/tests/sandbox/test_task_execution.py b/tests/sandbox/test_task_execution.py new file mode 100644 index 00000000..3c201cfe --- /dev/null +++ b/tests/sandbox/test_task_execution.py @@ -0,0 +1,334 @@ +"""Integrated and sandbox tests for task execution in container environments.""" + +import tempfile +from collections.abc import Generator +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import TicketType +from forge.sandbox.runner import ContainerConfig, ContainerRunner +from forge.workflow.nodes.task_takeover_execution import execute_task_changes +from forge.workflow.nodes.task_takeover_pr import cleanup_podman_containers +from forge.workflow.nodes.workspace_setup import teardown_workspace + + +def _make_state( + ticket_key: str = "TASK-123", + ticket_type: TicketType = TicketType.TASK, + workspace_path: str | None = "/tmp/ws", + current_repo: str = "acme/backend", + plan_content: str = "This is the approved plan.", + implemented_tasks: list[str] | None = None, +) -> dict[str, Any]: + return { + "ticket_key": ticket_key, + "ticket_type": ticket_type, + "current_node": "execute_task_changes", + "is_paused": False, + "retry_count": 0, + "last_error": None, + 
"workspace_path": workspace_path, + "current_repo": current_repo, + "plan_content": plan_content, + "implemented_tasks": implemented_tasks or [], + "context": {"branch_name": "forge/TASK-123", "guardrails": ""}, + } + + +def _make_mock_jira() -> AsyncMock: + jira = AsyncMock() + issue = MagicMock() + issue.summary = "Fix validation bug" + issue.description = "Validation logic in auth is failing" + jira.get_issue = AsyncMock(return_value=issue) + jira.add_comment = AsyncMock() + jira.close = AsyncMock() + return jira + + +def _make_mock_git(has_changes: bool = True, sha: str = "abcdef1234567890") -> MagicMock: + git = MagicMock() + git.has_uncommitted_changes = MagicMock(return_value=has_changes) + git.stage_all = MagicMock() + git.commit = MagicMock(return_value=True) + git.get_current_sha = MagicMock(return_value=sha) + return git + + +class TestTaskExecutionSandbox: + """Integrated tests verifying ContainerRunner and workflow task execution.""" + + @pytest.fixture(autouse=True) + def mock_podman_exists(self) -> Generator[None, None, None]: + with patch("shutil.which", return_value="/usr/bin/podman"): + yield + + @pytest.mark.asyncio + @patch("asyncio.create_subprocess_exec") + async def test_container_runner_successful_execution(self, mock_create_proc: AsyncMock) -> None: + """Test ContainerRunner correctly runs a task with successful output.""" + # Arrange + mock_proc = AsyncMock() + mock_proc.communicate = AsyncMock(return_value=(b"Agent finished successfully", b"")) + mock_proc.returncode = 0 + mock_create_proc.return_value = mock_proc + + runner = ContainerRunner() + config = ContainerConfig() + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_path = Path(tmpdir) + + # Act + result = await runner.run( + workspace_path=workspace_path, + task_summary="Add simple feature", + task_description="Implement some changes", + config=config, + ticket_key="TASK-123", + task_key="TASK-123", + repo_name="acme/backend", + ) + + # Assert + assert result.success is 
True + assert result.exit_code == 0 + assert "Agent finished successfully" in result.stdout + assert not (workspace_path / ".forge" / "task.json").exists() + + # Verify podman run command construction + mock_create_proc.assert_called_once() + cmd_args = mock_create_proc.call_args[0] + assert cmd_args[0] == "podman" + assert cmd_args[1] == "run" + assert f"{workspace_path}:/workspace:Z" in cmd_args + assert any("TASK-123" in arg for arg in cmd_args) + assert "--memory" in cmd_args + assert "--cpus" in cmd_args + + @pytest.mark.asyncio + @patch("asyncio.create_subprocess_exec") + async def test_execute_task_changes_successful_workflow( + self, mock_create_proc: AsyncMock + ) -> None: + """Test the execute_task_changes workflow node with successful container execution.""" + # Arrange + mock_proc = AsyncMock() + mock_proc.communicate = AsyncMock( + return_value=(b"Implementing changes...\nTests passed!", b"") + ) + mock_proc.returncode = 0 + mock_create_proc.return_value = mock_proc + + mock_jira = _make_mock_jira() + mock_git = _make_mock_git(has_changes=True, sha="9876543210abcdef") + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_path = Path(tmpdir) + state = _make_state(workspace_path=str(workspace_path)) + + with ( + patch( + "forge.workflow.nodes.task_takeover_execution.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.task_takeover_execution.GitOperations", + return_value=mock_git, + ), + patch("forge.workflow.nodes.task_takeover_execution.get_settings"), + ): + # Act + updated_state = await execute_task_changes(state) + + # Assert + assert updated_state["task_execution_results"]["success"] is True + assert updated_state["task_execution_results"]["exit_code"] == 0 + assert "Tests passed!" 
in updated_state["task_execution_logs"]["stdout"] + assert updated_state["commit_info"]["committed"] is True + assert updated_state["commit_info"]["sha"] == "9876543210abcdef" + assert updated_state["last_error"] is None + assert updated_state["retry_count"] == 0 + + # Verify JIRA interactions + mock_jira.get_issue.assert_called_once_with("TASK-123") + mock_jira.add_comment.assert_called() + mock_jira.close.assert_called_once() + + # Verify Git interactions on the host + mock_git.has_uncommitted_changes.assert_called_once() + mock_git.stage_all.assert_called_once() + mock_git.commit.assert_called_once_with( + "[TASK-123] feat: implement task takeover execution changes and tests" + ) + + @pytest.mark.asyncio + @patch("asyncio.create_subprocess_exec") + async def test_build_and_test_recovery_workflow_iterative_self_correction( + self, mock_create_proc: AsyncMock + ) -> None: + """Test build-and-test recovery workflow where compilation errors/test failures are fed back. + + We simulate a container execution that first fails (representing compilation/test failures), + captures the failure logs back to the state, and on the subsequent retry/run, + successfully implements self-correction and passes. 
+ """ + # --- FIRST RUN: Simulated compilation/test failure --- + mock_proc_fail = AsyncMock() + mock_proc_fail.communicate = AsyncMock( + return_value=( + b"Compiling and running tests...\nFailed!", + b"SyntaxError: invalid syntax at auth.py line 25", + ) + ) + mock_proc_fail.returncode = 2 # EXIT_TESTS_FAILED or EXIT_TASK_FAILED + mock_create_proc.return_value = mock_proc_fail + + mock_jira = _make_mock_jira() + mock_git_fail = _make_mock_git(has_changes=False) + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_path = Path(tmpdir) + state_initial = _make_state(workspace_path=str(workspace_path)) + + with ( + patch( + "forge.workflow.nodes.task_takeover_execution.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.task_takeover_execution.GitOperations", + return_value=mock_git_fail, + ), + patch("forge.workflow.nodes.task_takeover_execution.get_settings"), + ): + # Act + state_after_fail = await execute_task_changes(state_initial) + + # Assert first run failed as expected, recording logs and error feedback + assert state_after_fail["task_execution_results"]["success"] is False + assert state_after_fail["task_execution_results"]["exit_code"] == 2 + assert "SyntaxError" in state_after_fail["task_execution_logs"]["stderr"] + assert state_after_fail["retry_count"] == 1 + assert state_after_fail["commit_info"]["committed"] is False + + # Verify failure comment was posted to Jira + comment_calls = [call[0][1] for call in mock_jira.add_comment.call_args_list] + assert any("failed/exited with code 2" in msg for msg in comment_calls) + + # --- SECOND RUN: Simulated self-correction and success --- + mock_proc_success = AsyncMock() + mock_proc_success.communicate = AsyncMock( + return_value=( + b"Self-corrected auth.py.\nAll compilation checks and tests passed successfully!", + b"", + ) + ) + mock_proc_success.returncode = 0 + mock_create_proc.return_value = mock_proc_success + + mock_git_success = _make_mock_git(has_changes=True, 
sha="abcdef1234567890") + + # We pass the state containing the failure logs and incremented retry count back to simulate the self-correction step + with ( + patch( + "forge.workflow.nodes.task_takeover_execution.JiraClient", + return_value=mock_jira, + ), + patch( + "forge.workflow.nodes.task_takeover_execution.GitOperations", + return_value=mock_git_success, + ), + patch("forge.workflow.nodes.task_takeover_execution.get_settings"), + ): + # Act + state_after_success = await execute_task_changes(state_after_fail) + + # Assert second run succeeded after self-correction, resetting retry and committing changes + assert state_after_success["task_execution_results"]["success"] is True + assert state_after_success["task_execution_results"]["exit_code"] == 0 + assert "All compilation checks" in state_after_success["task_execution_logs"]["stdout"] + assert state_after_success["retry_count"] == 0 # Reset after success + assert state_after_success["commit_info"]["committed"] is True + assert state_after_success["commit_info"]["sha"] == "abcdef1234567890" + + # Verify success comment was posted to Jira + comment_calls_updated = [call[0][1] for call in mock_jira.add_comment.call_args_list] + assert any( + "Task takeover implementation succeeded" in msg for msg in comment_calls_updated + ) + + @pytest.mark.asyncio + @patch("asyncio.create_subprocess_exec") + async def test_cleanup_podman_containers_lifecycle(self, mock_create_proc: AsyncMock) -> None: + """Test cleanup_podman_containers finds, stops, and removes targeted containers securely.""" + # Arrange + mock_ps_proc = AsyncMock() + mock_ps_proc.communicate = AsyncMock(return_value=(b"forge-TASK-123-abc\n", b"")) + + mock_stop_proc = AsyncMock() + mock_stop_proc.wait = AsyncMock() + + mock_rm_proc = AsyncMock() + mock_rm_proc.wait = AsyncMock() + + def side_effect(*args: Any, **_kwargs: Any) -> AsyncMock: + if args[1] == "ps": + return mock_ps_proc + elif args[1] == "stop": + return mock_stop_proc + elif args[1] == "rm": + 
return mock_rm_proc + return AsyncMock() + + mock_create_proc.side_effect = side_effect + + # Act + await cleanup_podman_containers("TASK-123") + + # Assert + # Check that we queried the containers + assert mock_create_proc.call_count >= 3 + first_call_args = mock_create_proc.call_args_list[0][0] + assert first_call_args[0] == "podman" + assert first_call_args[1] == "ps" + assert "--filter" in first_call_args + assert "name=forge-TASK-123-" in first_call_args + + # Check that stop and rm were called on the returned container name + stop_called = False + rm_called = False + for call in mock_create_proc.call_args_list: + args = call[0] + if "stop" in args: + stop_called = True + assert "forge-TASK-123-abc" in args + if "rm" in args: + rm_called = True + assert "forge-TASK-123-abc" in args + + assert stop_called is True + assert rm_called is True + + @pytest.mark.asyncio + @patch("forge.workflow.nodes.workspace_setup.get_workspace_manager") + async def test_teardown_workspace_secure_destruction(self, mock_get_manager: MagicMock) -> None: + """Test teardown_workspace securely destroys the workspace and clears path in state.""" + # Arrange + state = _make_state(workspace_path="/tmp/ws-to-teardown") + mock_manager = MagicMock() + mock_workspace = MagicMock() + mock_manager.get_workspace.return_value = mock_workspace + mock_get_manager.return_value = mock_manager + + # Act + teardown_state = await teardown_workspace(state) + + # Assert + assert teardown_state["workspace_path"] is None + assert teardown_state["current_node"] == "workspace_complete" + mock_manager.get_workspace.assert_called_once_with("TASK-123", "acme/backend") + mock_manager.destroy_workspace.assert_called_once_with(mock_workspace) diff --git a/tests/unit/api/routes/test_jira_webhook.py b/tests/unit/api/routes/test_jira_webhook.py index bc18dcd3..cd991234 100644 --- a/tests/unit/api/routes/test_jira_webhook.py +++ b/tests/unit/api/routes/test_jira_webhook.py @@ -8,6 +8,8 @@ import pytest from httpx import 
ASGITransport, AsyncClient from pydantic import SecretStr + +from forge.main import app from tests.fixtures.jira_payloads import ( WEBHOOK_ISSUE_CREATED, WEBHOOK_ISSUE_UPDATED_COMMENT_ADDED, @@ -15,8 +17,6 @@ make_jira_webhook, ) -from forge.main import app - def compute_signature(payload: bytes, secret: str) -> str: """Compute Jira webhook signature with sha256= prefix.""" @@ -44,20 +44,21 @@ async def test_valid_webhook_returns_202(self): mock_producer = MagicMock() mock_producer.publish = AsyncMock() - with patch("forge.api.routes.jira.get_settings", return_value=mock_settings): - with patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer): - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test" - ) as client: - response = await client.post( - "/api/v1/webhooks/jira", - content=payload, - headers={ - "Content-Type": "application/json", - "X-Hub-Signature-256": signature, - }, - ) + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) assert response.status_code == 202 @@ -71,8 +72,7 @@ async def test_invalid_signature_returns_401(self): with patch("forge.api.routes.jira.get_settings", return_value=mock_settings): async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test" + transport=ASGITransport(app=app), base_url="http://test" ) as client: response = await client.post( "/api/v1/webhooks/jira", @@ -95,8 +95,7 @@ async def test_missing_signature_returns_401(self): with patch("forge.api.routes.jira.get_settings", return_value=mock_settings): async with AsyncClient( - transport=ASGITransport(app=app), - 
base_url="http://test" + transport=ASGITransport(app=app), base_url="http://test" ) as client: response = await client.post( "/api/v1/webhooks/jira", @@ -120,20 +119,21 @@ async def test_non_managed_issue_skipped(self): mock_producer = MagicMock() mock_producer.publish = AsyncMock() - with patch("forge.api.routes.jira.get_settings", return_value=mock_settings): - with patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer): - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test" - ) as client: - response = await client.post( - "/api/v1/webhooks/jira", - content=payload, - headers={ - "Content-Type": "application/json", - "X-Hub-Signature-256": signature, - }, - ) + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) assert response.status_code == 202 data = response.json() @@ -160,23 +160,237 @@ async def test_label_change_event_published(self): mock_producer = MagicMock() mock_producer.publish = AsyncMock() - with patch("forge.api.routes.jira.get_settings", return_value=mock_settings): - with patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer): - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test" - ) as client: - response = await client.post( - "/api/v1/webhooks/jira", - content=payload, - headers={ - "Content-Type": "application/json", - "X-Hub-Signature-256": signature, - }, - ) + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + 
transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) + + assert response.status_code == 202 + mock_producer.publish.assert_called_once() + + @pytest.mark.asyncio + async def test_standard_task_without_parent_skipped(self) -> None: + """Standard Task issues without forge:parent label or takeover triggers are skipped.""" + webhook = make_jira_webhook(issue_type="Task", labels=["forge:managed"]) + payload = json.dumps(webhook).encode() + secret = "test-webhook-secret" + signature = compute_signature(payload, secret) + + mock_settings = MagicMock() + mock_settings.jira_webhook_secret = SecretStr(secret) + mock_settings.task_takeover = MagicMock() + mock_settings.task_takeover.labels = MagicMock() + mock_settings.task_takeover.labels.trigger = "forge:task-takeover" + + mock_producer = MagicMock() + mock_producer.publish = AsyncMock() + + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) + + assert response.status_code == 202 + data = response.json() + assert data["status"] == "skipped" + assert "must have forge:parent label" in data["reason"] + mock_producer.publish.assert_not_called() + + @pytest.mark.asyncio + async def test_standard_task_with_parent_routed_to_parent(self) -> None: + """Standard Task issues with forge:parent label are routed to the parent ticket key.""" + webhook = make_jira_webhook( + issue_type="Task", labels=["forge:managed", "forge:parent:PARENT-123"] + ) + payload = 
json.dumps(webhook).encode() + secret = "test-webhook-secret" + signature = compute_signature(payload, secret) + + mock_settings = MagicMock() + mock_settings.jira_webhook_secret = SecretStr(secret) + mock_settings.task_takeover = MagicMock() + mock_settings.task_takeover.labels = MagicMock() + mock_settings.task_takeover.labels.trigger = "forge:task-takeover" + + mock_producer = MagicMock() + mock_producer.publish = AsyncMock() + + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) + + assert response.status_code == 202 + data = response.json() + assert data["status"] == "accepted" + mock_producer.publish.assert_called_once() + called_kwargs = mock_producer.publish.call_args.kwargs + assert called_kwargs["ticket_key"] == "PARENT-123" + assert called_kwargs["payload"]["source_ticket_key"] == "TEST-123" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "trigger_label", + ["forge:task-takeover", "forge:managed:task", "forge:managed:task-takeover"], + ) + async def test_task_with_takeover_trigger_bypasses_parent_check(self, trigger_label: str) -> None: + """Task issue with a task-takeover trigger label bypasses parent check and is queued under its own key.""" + webhook = make_jira_webhook(issue_type="Task", labels=[trigger_label]) + payload = json.dumps(webhook).encode() + secret = "test-webhook-secret" + signature = compute_signature(payload, secret) + + mock_settings = MagicMock() + mock_settings.jira_webhook_secret = SecretStr(secret) + mock_settings.task_takeover = MagicMock() + mock_settings.task_takeover.labels = MagicMock() + mock_settings.task_takeover.labels.trigger = 
"forge:task-takeover" + + mock_producer = MagicMock() + mock_producer.publish = AsyncMock() + + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) + + assert response.status_code == 202 + data = response.json() + assert data["status"] == "accepted" + mock_producer.publish.assert_called_once() + called_kwargs = mock_producer.publish.call_args.kwargs + assert called_kwargs["ticket_key"] == "TEST-123" + assert "source_ticket_key" not in called_kwargs["payload"] + + @pytest.mark.asyncio + async def test_task_with_takeover_trigger_in_changelog_bypasses_parent_check(self) -> None: + """Task issue with task-takeover trigger added in changelog bypasses parent check and is queued under its own key.""" + webhook = make_jira_webhook( + issue_type="Task", + labels=[], + changelog_field="labels", + changelog_from="some-other-label", + changelog_to="forge:managed:task-takeover", + ) + payload = json.dumps(webhook).encode() + secret = "test-webhook-secret" + signature = compute_signature(payload, secret) + + mock_settings = MagicMock() + mock_settings.jira_webhook_secret = SecretStr(secret) + mock_settings.task_takeover = MagicMock() + mock_settings.task_takeover.labels = MagicMock() + mock_settings.task_takeover.labels.trigger = "forge:task-takeover" + + mock_producer = MagicMock() + mock_producer.publish = AsyncMock() + + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await 
client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) + + assert response.status_code == 202 + data = response.json() + assert data["status"] == "accepted" + mock_producer.publish.assert_called_once() + called_kwargs = mock_producer.publish.call_args.kwargs + assert called_kwargs["ticket_key"] == "TEST-123" + + @pytest.mark.asyncio + async def test_task_with_custom_takeover_trigger_bypasses_parent_check(self) -> None: + """Task issue with a custom configured trigger label bypasses parent check and is queued under its own key.""" + webhook = make_jira_webhook(issue_type="Task", labels=["custom-trigger-label"]) + payload = json.dumps(webhook).encode() + secret = "test-webhook-secret" + signature = compute_signature(payload, secret) + + mock_settings = MagicMock() + mock_settings.jira_webhook_secret = SecretStr(secret) + mock_settings.task_takeover = MagicMock() + mock_settings.task_takeover.labels = MagicMock() + mock_settings.task_takeover.labels.trigger = "custom-trigger-label" + + mock_producer = MagicMock() + mock_producer.publish = AsyncMock() + + with ( + patch("forge.api.routes.jira.get_settings", return_value=mock_settings), + patch("forge.api.routes.jira.QueueProducer", return_value=mock_producer), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/webhooks/jira", + content=payload, + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": signature, + }, + ) assert response.status_code == 202 + data = response.json() + assert data["status"] == "accepted" mock_producer.publish.assert_called_once() + called_kwargs = mock_producer.publish.call_args.kwargs + assert called_kwargs["ticket_key"] == "TEST-123" class TestJiraWebhookParsing: diff --git a/tests/unit/models/test_workflow.py b/tests/unit/models/test_workflow.py index 21eaddc3..18f59f00 100644 
--- a/tests/unit/models/test_workflow.py
+++ b/tests/unit/models/test_workflow.py
@@ -1,6 +1,5 @@
 """Unit tests for workflow models."""
-
 from forge.models.workflow import (
     ForgeLabel,
     JiraStatus,
@@ -63,6 +62,13 @@ def test_plan_approved_label_value(self):
         """ForgeLabel.PLAN_APPROVED already exists with correct value."""
         assert ForgeLabel.PLAN_APPROVED.value == "forge:plan-approved"
+    def test_task_takeover_labels_exist(self) -> None:
+        """Verify Task Takeover workflow labels are defined."""
+        assert ForgeLabel.TASK_TAKEOVER.value == "forge:task-takeover"
+        assert ForgeLabel.TASK_TRIAGE_PENDING.value == "forge:task-triage-pending"
+        assert ForgeLabel.TASK_PLAN_PENDING.value == "forge:task-plan-pending"
+        assert ForgeLabel.TASK_PLAN_APPROVED.value == "forge:task-plan-approved"
+
     def test_general_labels_exist(self):
         """Verify general labels are defined."""
         assert ForgeLabel.FORGE_MANAGED.value == "forge:managed"
diff --git a/tests/unit/orchestrator/gates/test_task_plan_approval.py b/tests/unit/orchestrator/gates/test_task_plan_approval.py
new file mode 100644
index 00000000..33df2f08
--- /dev/null
+++ b/tests/unit/orchestrator/gates/test_task_plan_approval.py
@@ -0,0 +1,103 @@
+"""Unit tests for the task takeover plan approval gate and routing logic."""
+
+import pytest
+from langgraph.graph import END
+
+from forge.models.workflow import TicketType
+from forge.workflow.gates.task_plan_approval import (
+    route_task_plan_approval,
+    task_plan_approval_gate,
+)
+from forge.workflow.task_takeover.state import create_initial_task_takeover_state
+
+
+class TestTaskPlanApprovalGate:
+    """Tests for task_plan_approval_gate node."""
+
+    def test_gate_pauses_workflow(self) -> None:
+        """Gate sets is_paused=True and updates current_node."""
+        state = create_initial_task_takeover_state("TASK-100")
+        state["current_node"] = "generate_plan"
+
+        result = task_plan_approval_gate(state)
+
+        assert result["is_paused"] is True
+        assert result["current_node"] == "task_plan_approval_gate"
+
+
+class TestRouteTaskPlanApproval:
+    """Tests for route_task_plan_approval function."""
+
+    @pytest.fixture
+    def paused_state(self):
+        """Standard paused state at task plan approval gate."""
+        state = create_initial_task_takeover_state("TASK-100")
+        state["current_node"] = "task_plan_approval_gate"
+        state["is_paused"] = True
+        return state
+
+    def test_routes_to_end_when_still_paused(self, paused_state) -> None:
+        """If still paused and no signals are present, route to END."""
+        result = route_task_plan_approval(paused_state)
+        assert result == END
+
+    def test_routes_to_setup_workspace_on_approval(self, paused_state) -> None:
+        """When resumed with approval, is_paused is False and routes to setup_workspace."""
+        paused_state["is_paused"] = False
+
+        result = route_task_plan_approval(paused_state)
+        assert result == "setup_workspace"
+
+    def test_routes_to_regenerate_plan_on_feedback_comment(self, paused_state) -> None:
+        """Comment starting with '!' triggers feedback classification and routes to regenerate_plan."""
+        # Scenario A: feedback is processed by worker and comes in as revision_requested
+        state_worker = {
+            **paused_state,
+            "revision_requested": True,
+            "feedback_comment": "Please rewrite the logging part.",
+        }
+        assert route_task_plan_approval(state_worker) == "regenerate_plan"
+
+        # Scenario B: feedback comment with '!' is evaluated directly by the router (prefix integration check)
+        state_direct = {
+            **paused_state,
+            "feedback_comment": "!Please rewrite the logging part.",
+        }
+        assert route_task_plan_approval(state_direct) == "regenerate_plan"
+
+    def test_routes_to_answer_question_on_question_comment_with_prefix(self, paused_state) -> None:
+        """Comment starting with '?' or '@forge ask' triggers QUESTION classification and routes to answer_question."""
+        # Scenario A: is_question is set
+        state_worker = {
+            **paused_state,
+            "is_question": True,
+            "feedback_comment": "?Why use REST?",
+        }
+        assert route_task_plan_approval(state_worker) == "answer_question"
+
+        # Scenario B: comment starting with '?' is evaluated directly by prefix classifier
+        state_direct_question = {
+            **paused_state,
+            "feedback_comment": "?Why use REST?",
+        }
+        assert route_task_plan_approval(state_direct_question) == "answer_question"
+
+        # Scenario C: comment starting with '@forge ask' is evaluated directly by prefix classifier
+        state_direct_ask = {
+            **paused_state,
+            "feedback_comment": "@forge ask can you explain more?",
+        }
+        assert route_task_plan_approval(state_direct_ask) == "answer_question"
+
+    def test_yolo_mode_auto_approves(self, paused_state) -> None:
+        """YOLO mode routes directly to setup_workspace."""
+        paused_state["yolo_mode"] = True
+        result = route_task_plan_approval(paused_state)
+        assert result == "setup_workspace"
+
+    def test_informational_comment_ignored(self, paused_state) -> None:
+        """Standard informational comments do not trigger transition and stay in paused state (routes to END)."""
+        paused_state["feedback_comment"] = "This is a plain comment with no special prefix"
+        # Standard comments don't change is_paused to False, or set revision_requested/is_question
+        result = route_task_plan_approval(paused_state)
+        assert result == END
diff --git a/tests/unit/orchestrator/test_blocked_retry.py b/tests/unit/orchestrator/test_blocked_retry.py
index 5169b722..d709ccf1 100644
--- a/tests/unit/orchestrator/test_blocked_retry.py
+++ b/tests/unit/orchestrator/test_blocked_retry.py
@@ -54,7 +54,6 @@ def _make_retry_message(base: QueueMessage) -> QueueMessage:
     )
-
 class TestWorkerTerminalBlockedCheck:
     """Worker skips invocation when is_blocked=True, same as terminal nodes."""
@@ -78,10 +77,9 @@ async def fake_process(_message):
             mock_state.values = blocked_state
             terminal_nodes = ("complete", "complete_tasks", "aggregate_feature_status")
-            is_terminal_or_blocked = (
-                blocked_state.get("current_node") in terminal_nodes
-                or blocked_state.get("is_blocked", False)
-            )
+            is_terminal_or_blocked = blocked_state.get(
+                "current_node"
+            ) in terminal_nodes or blocked_state.get("is_blocked", False)
             if is_terminal_or_blocked:
                 return  # skipped
@@ -103,9 +101,8 @@ async def test_non_blocked_mid_workflow_is_invocable(self):
         }
         terminal_nodes = ("complete", "complete_tasks", "aggregate_feature_status")
-        is_terminal_or_blocked = (
-            state.get("current_node") in terminal_nodes
-            or state.get("is_blocked", False)
+        is_terminal_or_blocked = state.get("current_node") in terminal_nodes or state.get(
+            "is_blocked", False
         )
         assert is_terminal_or_blocked is False
@@ -130,9 +127,7 @@ async def test_retry_clears_is_blocked(self, worker, base_message):
             "context": {},
         }
-        result = await worker._handle_resume_event(
-            _make_retry_message(base_message), blocked_state
-        )
+        result = await worker._handle_resume_event(_make_retry_message(base_message), blocked_state)
         assert result.get("is_blocked") is False
@@ -152,9 +147,7 @@ async def test_retry_resets_ci_fix_attempts_unconditionally(self, worker, base_m
             "context": {},
         }
-        result = await worker._handle_resume_event(
-            _make_retry_message(base_message), blocked_state
-        )
+        result = await worker._handle_resume_event(_make_retry_message(base_message), blocked_state)
         assert result.get("ci_fix_attempt") == 0
@@ -174,9 +167,7 @@ async def test_retry_clears_last_error(self, worker, base_message):
             "context": {},
         }
-        result = await worker._handle_resume_event(
-            _make_retry_message(base_message), blocked_state
-        )
+        result = await worker._handle_resume_event(_make_retry_message(base_message), blocked_state)
         assert result.get("last_error") is None
@@ -196,9 +187,7 @@ async def test_retry_preserves_current_node(self, worker, base_message):
             "context": {},
         }
-        result = await worker._handle_resume_event(
-            _make_retry_message(base_message), blocked_state
-        )
+        result = await worker._handle_resume_event(_make_retry_message(base_message), blocked_state)
         assert result.get("current_node") == "ci_evaluator"
@@ -218,9 +207,7 @@ async def test_retry_marks_non_gate_node_for_fresh_invoke(self, worker, base_mes
             "context": {},
         }
-        result = await worker._handle_resume_event(
-            _make_retry_message(base_message), blocked_state
-        )
+        result = await worker._handle_resume_event(_make_retry_message(base_message), blocked_state)
         assert result.get("context", {}).get("force_fresh_invoke") is True
@@ -244,9 +231,7 @@ async def test_retry_on_non_terminal_no_error_still_resumes(self, worker, base_m
             "context": {},
         }
-        result = await worker._handle_resume_event(
-            _make_retry_message(base_message), stuck_state
-        )
+        result = await worker._handle_resume_event(_make_retry_message(base_message), stuck_state)
         assert result.get("is_paused") is False
         assert result.get("last_error") is None
@@ -282,3 +267,31 @@ async def test_retry_on_terminal_no_error_posts_comment(self, worker, base_messa
         assert result.get("current_node") == "complete"
         # And the user must be informed via a Jira comment
         worker._post_terminal_error_comment.assert_called_once()
+
+
+class TestRetryAtTaskPlanApprovalGate:
+    """Tests for forge:retry at task_plan_approval_gate."""
+
+    @pytest.mark.asyncio
+    async def test_retry_at_task_plan_approval_gate_sets_revision_requested(
+        self, worker, base_message
+    ):
+        """forge:retry at task_plan_approval_gate sets revision_requested=True."""
+        state = {
+            "ticket_key": "TEST-123",
+            "current_node": "task_plan_approval_gate",
+            "is_paused": True,
+            "is_blocked": False,
+            "last_error": None,
+            "ci_fix_attempt": 0,
+            "retry_count": 0,
+            "revision_requested": False,
+            "feedback_comment": None,
+            "context": {},
+        }
+
+        result = await worker._handle_resume_event(_make_retry_message(base_message), state)
+
+        assert result.get("is_paused") is False
+        assert result.get("revision_requested") is True
+        assert result.get("feedback_comment") == "Regeneration requested via retry."
diff --git a/tests/unit/orchestrator/test_worker.py b/tests/unit/orchestrator/test_worker.py
index 5b64958e..0984a4a4 100644
--- a/tests/unit/orchestrator/test_worker.py
+++ b/tests/unit/orchestrator/test_worker.py
@@ -816,3 +816,242 @@ def test_bullet_list_text(self):
     def test_non_dict_returns_string(self):
         assert OrchestratorWorker._extract_text_from_adf("plain") == "plain"
         assert OrchestratorWorker._extract_text_from_adf(None) == ""
+
+
+class TestTaskPlanApprovalAndLabelPreservation:
+    """Tests for task plan approval resumption, YOLO gate, and label preservation."""
+
+    @pytest.fixture(autouse=True)
+    def ack_comment_mocks(self):
+        """Mock Jira acknowledgement posting for direct resume-event tests."""
+        mock_jira = AsyncMock()
+        mock_jira.close = AsyncMock()
+        with (
+            patch("forge.orchestrator.worker.JiraClient", return_value=mock_jira),
+            patch("forge.orchestrator.worker.post_status_comment", new_callable=AsyncMock) as post,
+        ):
+            yield post
+
+    @pytest.fixture
+    def worker(self) -> OrchestratorWorker:
+        """Create a worker instance for testing."""
+        return OrchestratorWorker(consumer_name="test-worker")
+
+    @pytest.fixture
+    def base_message(self) -> QueueMessage:
+        """Create a base queue message for testing."""
+        return QueueMessage(
+            message_id="1234567890-0",
+            event_id="test-event-001",
+            source=EventSource.JIRA,
+            event_type="jira:issue_updated",
+            ticket_key="TEST-123",
+            payload={
+                "issue": {
+                    "key": "TEST-123",
+                    "fields": {
+                        "issuetype": {"name": "Task"},
+                        "labels": ["forge:managed"],
+                    },
+                },
+            },
+        )
+
+    @pytest.fixture
+    def base_state(self) -> dict:
+        """Create a base workflow state for testing."""
+        return {
+            "ticket_key": "TEST-123",
+            "ticket_type": "Task",
+            "current_node": "task_plan_approval_gate",
+            "is_paused": True,
+            "context": {},
+        }
+
+    @pytest.mark.asyncio
+    async def test_task_plan_label_change_to_approved_sets_approved_flag(
+        self, worker: OrchestratorWorker, base_message: QueueMessage, base_state: dict
+    ):
+        """Approval for task plan is detected via label change from pending to approved."""
+        payload = {
+            **base_message.payload,
+            "changelog": {
+                "items": [
+                    {
+                        "field": "labels",
+                        "fromString": "forge:managed forge:task-plan-pending",
+                        "toString": "forge:managed forge:task-plan-approved",
+                    }
+                ]
+            },
+        }
+        message = QueueMessage(
+            message_id=base_message.message_id,
+            event_id=base_message.event_id,
+            source=base_message.source,
+            event_type="jira:issue_updated",
+            ticket_key=base_message.ticket_key,
+            payload=payload,
+        )
+
+        result = await worker._handle_resume_event(message, base_state)
+
+        assert result["is_paused"] is False
+        assert result.get("revision_requested") is not True
+
+    @pytest.mark.asyncio
+    async def test_task_plan_label_fallback_approved(
+        self, worker: OrchestratorWorker, base_message: QueueMessage, base_state: dict
+    ):
+        """Fallback detection: check current labels on the ticket when changelog check missed it."""
+        payload = {
+            **base_message.payload,
+            "issue": {
+                "key": "TEST-123",
+                "fields": {
+                    "issuetype": {"name": "Task"},
+                    "labels": ["forge:managed", "forge:task-plan-approved"],
+                },
+            },
+            "changelog": {"items": []},
+        }
+        message = QueueMessage(
+            message_id=base_message.message_id,
+            event_id=base_message.event_id,
+            source=base_message.source,
+            event_type="jira:issue_updated",
+            ticket_key=base_message.ticket_key,
+            payload=payload,
+        )
+
+        result = await worker._handle_resume_event(message, base_state)
+
+        assert result["is_paused"] is False
+        assert result.get("revision_requested") is not True
+
+    @pytest.mark.asyncio
+    async def test_task_plan_yolo_gate_activation(
+        self, worker: OrchestratorWorker, base_message: QueueMessage, base_state: dict
+    ):
+        """Adding forge:yolo label at task_plan_approval_gate activates YOLO mode."""
+        payload = {
+            **base_message.payload,
+            "changelog": {
+                "items": [
+                    {
+                        "field": "labels",
+                        "fromString": "forge:managed",
+                        "toString": "forge:managed forge:yolo",
+                    }
+                ]
+            },
+        }
+        message = QueueMessage(
+            message_id=base_message.message_id,
+            event_id=base_message.event_id,
+            source=base_message.source,
+            event_type="jira:issue_updated",
+            ticket_key=base_message.ticket_key,
+            payload=payload,
+        )
+
+        result = await worker._handle_resume_event(message, base_state)
+
+        assert result["yolo_mode"] is True
+        assert result["is_paused"] is False
+
+    @pytest.mark.asyncio
+    async def test_label_preservation_during_transitions(self):
+        """Transitions do not clear identity preservation labels forge:managed:task and forge:managed:task-takeover."""
+        from forge.integrations.jira.client import JiraClient
+        from forge.models.workflow import ForgeLabel
+
+        # Mock settings for JiraClient instantiation
+        with patch("forge.integrations.jira.client.get_settings") as mock_settings:
+            mock_settings.return_value.jira_base_url = "https://test.atlassian.net"
+            mock_settings.return_value.jira_api_token = MagicMock()
+            mock_settings.return_value.jira_api_token.get_secret_value.return_value = "token"
+            mock_settings.return_value.jira_user_email = "test@example.com"
+
+            client = JiraClient()
+
+        # Mock get_labels to return current labels including identity preservation ones
+        client.get_labels = AsyncMock(
+            return_value=[
+                "forge:managed",
+                "forge:task-plan-pending",
+                "forge:managed:task",
+                "forge:managed:task-takeover",
+                "other-label",
+            ]
+        )
+
+        mock_response = MagicMock()
+        mock_response.raise_for_status = MagicMock()
+
+        with patch.object(client, "_get_client") as mock_get_client:
+            mock_http = AsyncMock()
+            mock_http.put = AsyncMock(return_value=mock_response)
+            mock_get_client.return_value = mock_http
+
+            await client.set_workflow_label("TEST-123", ForgeLabel.TASK_PLAN_APPROVED)
+
+        # Check that PUT was called with correct operations
+        mock_http.put.assert_called_once()
+        call_args = mock_http.put.call_args
+        update_ops = call_args.kwargs["json"]["update"]["labels"]
+
+        # Assert no remove operations are queued for the identity labels
+        remove_ops = [op for op in update_ops if "remove" in op]
+        assert not any(op["remove"] == "forge:managed:task" for op in remove_ops)
+        assert not any(op["remove"] == "forge:managed:task-takeover" for op in remove_ops)
+
+        # Verify that "forge:task-plan-pending" is removed
+        assert any(op["remove"] == "forge:task-plan-pending" for op in remove_ops)
+        # Verify that "forge:task-plan-approved" is added
+        add_ops = [op for op in update_ops if "add" in op]
+        assert any(op["add"] == ForgeLabel.TASK_PLAN_APPROVED.value for op in add_ops)
+
+
+class TestWorkerRouting:
+    """Tests for message routing and label extraction in the worker."""
+
+    @pytest.mark.asyncio
+    async def test_process_workflow_extracts_labels_and_calls_resolve(self):
+        """Worker extracts labels from the payload and passes them to the router."""
+        from forge.models.workflow import TicketType
+
+        worker = OrchestratorWorker(consumer_name="test-worker")
+
+        message = QueueMessage(
+            message_id="1234567890-0",
+            event_id="test-event-001",
+            source=EventSource.JIRA,
+            event_type="jira:issue_updated",
+            ticket_key="TEST-123",
+            payload={
+                "issue": {
+                    "key": "TEST-123",
+                    "fields": {
+                        "issuetype": {"name": "Task"},
+                        "labels": ["forge:task-takeover"],
+                    },
+                },
+            },
+        )
+
+        mock_router = MagicMock()
+        mock_router.resolve = MagicMock(return_value=None)
+        worker.router = mock_router
+
+        with (
+            patch("forge.orchestrator.worker.ensure_skills", AsyncMock()),
+            patch("forge.orchestrator.worker.JiraClient"),
+        ):
+            await worker._process_workflow(message)
+
+        mock_router.resolve.assert_called_once_with(
+            ticket_type=TicketType.TASK,
+            labels=["forge:task-takeover"],
+            event=message.payload,
+        )
diff --git a/tests/unit/prompts/test_prompt_templates.py b/tests/unit/prompts/test_prompt_templates.py
index f4c55d27..c1733421 100644
--- a/tests/unit/prompts/test_prompt_templates.py
+++ b/tests/unit/prompts/test_prompt_templates.py
@@ -57,6 +57,10 @@ def test_list_prompts_for_v1(self):
             "decompose-epics",
             "analyze-bug",
             "regenerate",
+            "task-takeover-triage",
+            "task-takeover-planning",
+            "task-takeover-qa",
+            "task-takeover-review",
         ]
         for expected in expected_prompts:
@@ -157,7 +161,10 @@ def test_generate_tasks_preserves_bounded_repo_grounding(self):
             existing_tasks_section="None",
         )
-        assert "Prefer additional codebase exploration only for missing implementation details" in result
+        assert (
+            "Prefer additional codebase exploration only for missing implementation details"
+            in result
+        )
         assert "broaden the search when needed" in result
         assert "unrelated branches, open issues, pull requests" in result
         assert "nearby source/test patterns" in result
@@ -255,6 +262,56 @@ def test_generate_prd_prompt_structure(self):
         assert "Test requirements" in result
         assert "Test context" in result
+    def test_task_takeover_triage_prompt(self):
+        """task-takeover-triage prompt should enforce strict evaluation of the three mandatory sections."""
+        result = load_prompt(
+            "task-takeover-triage",
+            summary="Test summary",
+            description="Test description",
+            comments="Test comments",
+        )
+
+        assert "Problem Statement" in result
+        assert "Proposed Solution/Approach" in result
+        assert "Acceptance Criteria" in result
+        assert "Test description" in result
+        assert "Test comments" in result
+
+    def test_task_takeover_planning_prompt(self):
+        """task-takeover-planning prompt should map solutions to repository files and test plans."""
+        result = load_prompt(
+            "task-takeover-planning",
+            ticket_key="AISOS-1234",
+            summary="Test summary",
+            description="Test description",
+            comments="Test comments",
+            known_repos="acme/repo",
+            file_metadata="file1.py\nfile2.py",
+        )
+
+        assert "AISOS-1234" in result
+        assert "acme/repo" in result
+        assert "file1.py" in result
+        assert "Target Files" in result
+        assert "Test Plans" in result
+        assert "Implementation Steps" in result
+
+    def test_task_takeover_qa_prompt(self):
+        """task-takeover-qa prompt should provide guidelines for contextual Q&A during planning."""
+        result = load_prompt(
+            "task-takeover-qa",
+            ticket_key="AISOS-1234",
+            summary="Test summary",
+            description="Test description",
+            plan_content="Test plan content",
+            question="What is the test plan?",
+        )
+
+        assert "AISOS-1234" in result
+        assert "Test plan content" in result
+        assert "What is the test plan?" in result
+        assert "clarifying question" in result
+
     def test_prompts_are_reasonable_length(self):
         """Prompts should not be excessively long (sanity check)."""
         # A rough estimate: 1 token ~ 4 characters
@@ -276,13 +333,13 @@ def test_prompt_with_special_characters_in_value(self):
         """Variables with special characters should be handled."""
         result = load_prompt(
             "generate-prd",
-            raw_requirements="Test with $pecial ch@racters & symbols < > \"quotes\"",
+            raw_requirements='Test with $pecial ch@racters & symbols < > "quotes"',
             context="Normal context",
         )
         assert "$pecial" in result
         assert "ch@racters" in result
-        assert "\"quotes\"" in result
+        assert '"quotes"' in result
     def test_prompt_with_multiline_value(self):
         """Multiline variable values should be preserved."""
@@ -319,7 +376,7 @@ def test_prompt_with_curly_braces_in_content(self):
         # This documents current behavior
         result = load_prompt(
             "generate-prd",
-            raw_requirements="JSON: {\"key\": \"value\"}",
+            raw_requirements='JSON: {"key": "value"}',
             context="Normal",
         )
diff --git a/tests/unit/test_config_prd.py b/tests/unit/test_config_prd.py
index 13d5c78c..a43bffd0 100644
--- a/tests/unit/test_config_prd.py
+++ b/tests/unit/test_config_prd.py
@@ -4,7 +4,7 @@
 class TestPrdApprovalConfig:
-    def test_default_proposals_repo_is_empty(self):
+    def test_default_proposals_repo_is_empty(self) -> None:
         settings = Settings(
             jira_base_url="https://test.atlassian.net",
             jira_api_token="test",
@@ -14,7 +14,7 @@ def test_default_proposals_repo_is_empty(self):
         )
         assert settings.prd_proposals_repo == ""
-    def test_default_proposals_path(self):
+    def test_default_proposals_path(self) -> None:
         settings = Settings(
             jira_base_url="https://test.atlassian.net",
             jira_api_token="test",
@@ -24,7 +24,7 @@ def test_default_proposals_path(self):
         )
         assert settings.prd_proposals_path == ""
-    def test_proposals_repo_can_be_set_as_global_fallback(self):
+    def test_proposals_repo_can_be_set_as_global_fallback(self) -> None:
         settings = Settings(
             jira_base_url="https://test.atlassian.net",
             jira_api_token="test",
@@ -34,3 +34,53 @@ def test_proposals_repo_can_be_set_as_global_fallback(self):
             prd_proposals_repo="org/proposals",
         )
         assert settings.prd_proposals_repo == "org/proposals"
+
+
+class TestTaskTakeoverConfig:
+    def test_default_task_takeover_settings(self) -> None:
+        settings = Settings(
+            jira_base_url="https://test.atlassian.net",
+            jira_api_token="test",
+            jira_user_email="test@example.com",
+            github_token="test",
+            anthropic_api_key="test",
+        )
+        assert settings.task_takeover.enabled is False
+        assert settings.task_takeover.issue_types == []
+        assert settings.task_takeover.require_tests is True
+        assert settings.task_takeover.review_max_attempts == 2
+
+        # Verify default labels
+        labels = settings.task_takeover.labels
+        assert labels.trigger == "forge:task-takeover"
+        assert labels.pending == "forge:task-plan-pending"
+        assert labels.approved == "forge:task-plan-approved"
+
+    def test_override_task_takeover_settings(self) -> None:
+        settings = Settings(
+            jira_base_url="https://test.atlassian.net",
+            jira_api_token="test",
+            jira_user_email="test@example.com",
+            github_token="test",
+            anthropic_api_key="test",
+            task_takeover={
+                "enabled": True,
+                "issue_types": ["Bug", "Feature"],
+                "labels": {
+                    "trigger": "custom-trigger",
+                    "pending": "custom-pending",
+                    "approved": "custom-approved",
+                },
+                "require_tests": False,
+                "review_max_attempts": 3,
+            },
+        )
+        assert settings.task_takeover.enabled is True
+        assert settings.task_takeover.issue_types == ["Bug", "Feature"]
+        assert settings.task_takeover.require_tests is False
+        assert settings.task_takeover.review_max_attempts == 3
+
+        labels = settings.task_takeover.labels
+        assert labels.trigger == "custom-trigger"
+        assert labels.pending == "custom-pending"
+        assert labels.approved == "custom-approved"
diff --git a/tests/unit/workflow/nodes/test_qa_handler.py b/tests/unit/workflow/nodes/test_qa_handler.py
index a233d855..9474ba91 100644
--- a/tests/unit/workflow/nodes/test_qa_handler.py
+++ b/tests/unit/workflow/nodes/test_qa_handler.py
@@ -654,3 +654,37 @@ async def test_answer_question_at_plan_approval_gate_stays_paused(self):
         assert result["current_node"] == "plan_approval_gate"
         assert result["is_question"] is False
         assert result["feedback_comment"] is None
+
+    @pytest.mark.asyncio
+    async def test_answer_question_at_task_plan_approval_gate(self):
+        """answer_question at task_plan_approval_gate passes context and returns is_paused=True."""
+        state = {
+            "ticket_key": "TASK-123",
+            "ticket_type": TicketType.TASK,
+            "current_node": "task_plan_approval_gate",
+            "is_paused": True,
+            "is_question": True,
+            "feedback_comment": "?What is the approach?",
+            "plan_content": "## Task Plan",
+            "qa_history": [],
+            "generation_context": {},
+            "revision_requested": False,
+        }
+
+        mock_jira = create_mock_jira_client()
+        mock_agent = create_mock_forge_agent()
+
+        with (
+            patch("forge.workflow.nodes.qa_handler.JiraClient", return_value=mock_jira),
+            patch("forge.workflow.nodes.qa_handler.ForgeAgent", return_value=mock_agent),
+        ):
+            result = await answer_question(state)
+
+        assert result["is_paused"] is True
+        assert result["current_node"] == "task_plan_approval_gate"
+        assert result["is_question"] is False
+        assert result["feedback_comment"] is None
+        mock_agent.answer_question.assert_called_once()
+        call_kwargs = mock_agent.answer_question.call_args.kwargs
+        assert call_kwargs["context"]["ticket_type"] == TicketType.TASK
+        assert call_kwargs["context"]["current_node"] == "task_plan_approval_gate"
diff --git a/tests/unit/workflow/nodes/test_task_takeover_execution.py b/tests/unit/workflow/nodes/test_task_takeover_execution.py
new file mode 100644
index 00000000..7e367b98
--- /dev/null
+++ b/tests/unit/workflow/nodes/test_task_takeover_execution.py
@@ -0,0 +1,197 @@
+"""Unit tests for task takeover execution node."""
+
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from forge.models.workflow import TicketType
+from forge.workflow.nodes.task_takeover_execution import execute_task_changes
+
+
+def _make_state(
+    ticket_key="TASK-123",
+    ticket_type=TicketType.TASK,
+    workspace_path="/tmp/ws",
+    current_repo="acme/backend",
+    plan_content="This is the plan content.",
+    implemented_tasks=None,
+):
+    return {
+        "ticket_key": ticket_key,
+        "ticket_type": ticket_type,
+        "current_node": "execute_task_changes",
+        "is_paused": False,
+        "retry_count": 0,
+        "last_error": None,
+        "workspace_path": workspace_path,
+        "current_repo": current_repo,
+        "plan_content": plan_content,
+        "implemented_tasks": implemented_tasks or [],
+        "context": {"branch_name": "forge/TASK-123", "guardrails": ""},
+    }
+
+
+def _make_mock_jira(
+    summary="Implement user authentication", description="Details of the authentication task"
+):
+    jira = AsyncMock()
+    issue = MagicMock()
+    issue.summary = summary
+    issue.description = description
+    jira.get_issue = AsyncMock(return_value=issue)
+    jira.add_comment = AsyncMock()
+    jira.close = AsyncMock()
+    return jira
+
+
+def _make_mock_runner(
+    success=True, exit_code=0, stdout="Build successful", stderr="", error_message=None
+):
+    runner = MagicMock()
+    result = MagicMock()
+    result.success = success
+    result.exit_code = exit_code
+    result.stdout = stdout
+    result.stderr = stderr
+    result.error_message = error_message
+    runner.run = AsyncMock(return_value=result)
+    return runner
+
+
+def _make_mock_git(has_changes=True, sha="abcdef1234567890"):
+    git = MagicMock()
+    git.has_uncommitted_changes = MagicMock(return_value=has_changes)
+    git.stage_all = MagicMock()
+    git.commit = MagicMock(return_value=True)
+    git.get_current_sha = MagicMock(return_value=sha)
+    return git
+
+
+class TestTaskTakeoverExecutionNode:
+    """Tests for execute_task_changes node in Task Takeover workflow."""
+
+    @pytest.mark.asyncio
+    async def test_successful_execution(self) -> None:
+        """Test successful task takeover execution with code modifications and tests."""
+        state = _make_state()
+        mock_jira = _make_mock_jira()
+        mock_runner = _make_mock_runner()
+        mock_git = _make_mock_git()
+
+        with (
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.JiraClient", return_value=mock_jira
+            ),
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.ContainerRunner",
+                return_value=mock_runner,
+            ),
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.GitOperations", return_value=mock_git
+            ),
+            patch("forge.workflow.nodes.task_takeover_execution.get_settings"),
+        ):
+            result_state = await execute_task_changes(state)
+
+        # Assertions on state results
+        assert result_state["task_execution_results"]["success"] is True
+        assert result_state["task_execution_results"]["exit_code"] == 0
+        assert result_state["task_execution_logs"]["stdout"] == "Build successful"
+        assert result_state["commit_info"]["committed"] is True
+        assert result_state["commit_info"]["sha"] == "abcdef1234567890"
+        assert result_state["last_error"] is None
+        assert result_state["retry_count"] == 0
+
+        # Verify JIRA Client was called
+        mock_jira.get_issue.assert_called_once_with("TASK-123")
+        mock_jira.add_comment.assert_called()
+        mock_jira.close.assert_called_once()
+
+        # Verify ContainerRunner was called with correct parameters
+        mock_runner.run.assert_called_once()
+        kwargs = mock_runner.run.call_args.kwargs
+        assert kwargs["workspace_path"] == Path("/tmp/ws")
+        assert "Approved Implementation Plan" in kwargs["task_description"]
+        assert "inject at least one new or modified test file" in kwargs["task_description"]
+
+        # Verify GitOperations were performed
+        mock_git.has_uncommitted_changes.assert_called_once()
+        mock_git.stage_all.assert_called_once()
+        mock_git.commit.assert_called_once()
+        mock_git.get_current_sha.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_execution_failure(self) -> None:
+        """Test that execution failures are recorded as non-blocking metrics/results in state."""
+        state = _make_state()
+        mock_jira = _make_mock_jira()
+        mock_runner = _make_mock_runner(
+            success=False, exit_code=2, stderr="Compilation error", error_message="Tests failed"
+        )
+        mock_git = _make_mock_git(has_changes=False)
+
+        with (
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.JiraClient", return_value=mock_jira
+            ),
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.ContainerRunner",
+                return_value=mock_runner,
+            ),
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.GitOperations", return_value=mock_git
+            ),
+            patch("forge.workflow.nodes.task_takeover_execution.get_settings"),
+        ):
+            result_state = await execute_task_changes(state)
+
+        # Non-blocking compilation and test execution failures: we update state and return it gracefully
+        assert result_state["task_execution_results"]["success"] is False
+        assert result_state["task_execution_results"]["exit_code"] == 2
+        assert result_state["task_execution_results"]["error_message"] == "Tests failed"
+        assert result_state["task_execution_logs"]["stderr"] == "Compilation error"
+        assert result_state["commit_info"]["committed"] is False
+        assert result_state["retry_count"] == 1
+
+        mock_git.commit.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_missing_workspace_path(self) -> None:
+        """Test graceful error handling when workspace_path is not set up."""
+        state = _make_state(workspace_path=None)
+        mock_jira = _make_mock_jira()
+
+        with (
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.JiraClient", return_value=mock_jira
+            ),
+            patch("forge.workflow.nodes.task_takeover_execution.get_settings"),
+        ):
+            result_state = await execute_task_changes(state)
+
+        assert result_state["last_error"] == "Workspace not set up"
+        assert result_state["current_node"] == "execute_task_changes"
+
+    @pytest.mark.asyncio
+    async def test_unexpected_exception(self) -> None:
+        """Test that unexpected exceptions are caught, logged, and updated in state."""
+        state = _make_state()
+        mock_jira = _make_mock_jira()
+        mock_jira.get_issue.side_effect = Exception("Jira Connection Error")
+
+        with (
+            patch(
+                "forge.workflow.nodes.task_takeover_execution.JiraClient", return_value=mock_jira
+            ),
+            patch("forge.workflow.nodes.task_takeover_execution.get_settings"),
+            patch(
+                "forge.workflow.nodes.error_handler.notify_error", new=AsyncMock()
+            ) as mock_notify,
+        ):
+            result_state = await execute_task_changes(state)
+
+        assert result_state["last_error"] == "Jira Connection Error"
+        assert result_state["current_node"] == "execute_task_changes"
+        assert result_state["retry_count"] == 1
+        mock_notify.assert_called_once()
diff --git a/tests/unit/workflow/nodes/test_task_takeover_planning.py b/tests/unit/workflow/nodes/test_task_takeover_planning.py
new file mode 100644
index 00000000..2dedc089
--- /dev/null
+++ b/tests/unit/workflow/nodes/test_task_takeover_planning.py
@@ -0,0 +1,214 @@
+"""Unit tests for task takeover planning nodes."""
+
+from typing import Any, cast
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from langgraph.graph import END
+
+from forge.models.workflow import ForgeLabel
+from forge.workflow.nodes.task_takeover_planning import (
+    generate_plan,
+    plan_approval_gate,
+    route_plan_approval,
+)
+from forge.workflow.task_takeover.state import (
+    TaskTakeoverState,
+    create_initial_task_takeover_state,
+)
+
+
+def make_task_state(**overrides: Any) -> TaskTakeoverState:
+    """Create a TaskTakeoverState dict for planning tests."""
+    state = create_initial_task_takeover_state("TASK-002")
+    state_dict = cast(dict[str, Any], state)
+    state_dict.update(overrides)
+    return cast(TaskTakeoverState, state_dict)
+
+
+@pytest.fixture
+def base_task_state() -> TaskTakeoverState:
+    return make_task_state()
+
+
+def _make_mock_jira(summary="Implement user session logout", project_key="TASK"):
+    jira = AsyncMock()
+    issue = MagicMock()
+    issue.summary = summary
+    issue.description = "Task description"
+    issue.project_key = project_key
+    jira.get_issue = AsyncMock(return_value=issue)
+    jira.get_comments = AsyncMock(return_value=[])
+    jira.add_comment = AsyncMock()
+    jira.set_workflow_label = AsyncMock()
+    jira.get_project_default_repo = AsyncMock(return_value="owner/project")
+    jira.get_project_repos = AsyncMock(return_value=["owner/project"])
+    jira.close = AsyncMock()
+    return jira
+
+
+def _make_mock_runner_success(plan_content="## Plan\n\nTask Takeover Plan details."):
+    class _FakeRunner:
+        async def run(self, workspace_path, **_kwargs):
+            forge_dir = workspace_path / ".forge"
+            forge_dir.mkdir(exist_ok=True, parents=True)
+            (forge_dir / "plan.md").write_text(plan_content)
+            result = MagicMock()
+            result.success = True
+            result.exit_code = 0
+            result.stdout = "Done"
+            result.stderr = ""
+            return result
+
+    return _FakeRunner()
+
+
+def _make_mock_runner_failure():
+    runner = MagicMock()
+    result = MagicMock()
+    result.success = False
+    result.exit_code = 1
+    result.stdout = ""
+    result.stderr = "Container failed"
+    runner.run = AsyncMock(return_value=result)
+    return runner
+
+
+class TestGeneratePlan:
+    """Tests for the generate_plan node."""
+
+    @pytest.mark.asyncio
+    async def test_generate_plan_success(self, base_task_state: TaskTakeoverState) -> None:
+        """Verify successful generation of task takeover plan."""
+        mock_jira = _make_mock_jira()
+        runner = _make_mock_runner_success("## Plan\n\nTask Takeover Plan details.")
+
+        with (
+            patch("forge.workflow.nodes.task_takeover_planning.JiraClient", return_value=mock_jira),
+            patch(
+                "forge.workflow.nodes.task_takeover_planning.ContainerRunner", return_value=runner
+            ),
+            patch("forge.workflow.nodes.task_takeover_planning.GitOperations") as mock_git,
+        ):
+            mock_git_instance = MagicMock()
+            mock_git_instance.clone = MagicMock()
+            mock_git.return_value = mock_git_instance
+            result = await generate_plan(base_task_state)
+
+        assert result["plan_content"] == "## Plan\n\nTask Takeover Plan details."
+        assert result["current_node"] == "task_plan_approval_gate"
+        mock_jira.set_workflow_label.assert_called_once_with(
+            "TASK-002", ForgeLabel.TASK_PLAN_PENDING
+        )
+        assert mock_jira.add_comment.call_count == 2  # Ack comment + Plan comment
+
+    @pytest.mark.asyncio
+    async def test_generate_plan_with_truncation(self, base_task_state: TaskTakeoverState) -> None:
+        """Verify plan comment is truncated if it exceeds maximum comment size."""
+        mock_jira = _make_mock_jira()
+        long_plan = "A" * 30_000
+        runner = _make_mock_runner_success(long_plan)
+
+        with (
+            patch("forge.workflow.nodes.task_takeover_planning.JiraClient", return_value=mock_jira),
+            patch(
+                "forge.workflow.nodes.task_takeover_planning.ContainerRunner", return_value=runner
+            ),
+            patch("forge.workflow.nodes.task_takeover_planning.GitOperations") as mock_git,
+        ):
+            mock_git_instance = MagicMock()
+            mock_git_instance.clone = MagicMock()
+            mock_git.return_value = mock_git_instance
+            await generate_plan(base_task_state)
+
+        # Plan comment is the second comment
+        plan_comment = mock_jira.add_comment.call_args_list[1].args[1]
+        assert len(plan_comment) <= 25_500
+        assert "truncated" in plan_comment.lower()
+
+    @pytest.mark.asyncio
+    async def test_generate_plan_failure_retries(self, base_task_state: TaskTakeoverState) -> None:
+        """Verify container failure increments retry_count and handles errors."""
+        mock_jira = _make_mock_jira()
+        runner = _make_mock_runner_failure()
+
+        with (
+            patch("forge.workflow.nodes.task_takeover_planning.JiraClient", return_value=mock_jira),
+            patch(
+                "forge.workflow.nodes.task_takeover_planning.ContainerRunner", return_value=runner
+            ),
+            patch("forge.workflow.nodes.task_takeover_planning.GitOperations") as mock_git,
+        ):
+            mock_git_instance = MagicMock()
+            mock_git_instance.clone = MagicMock()
+            mock_git.return_value = mock_git_instance
+            result = await generate_plan(base_task_state)
+
+        assert result["retry_count"] == 1
+        assert result["last_error"] is not None
+        assert result["current_node"] == "generate_plan"
+
+
+class TestRegeneratePlanFlow:
+    """Tests for the regeneration flow when a revision is requested."""
+
+    @pytest.mark.asyncio
+    async def test_regenerate_plan_with_feedback(self, base_task_state: TaskTakeoverState) -> None:
+        """Verify regenerate plan with revision request and feedback details."""
+        state = {
+            **base_task_state,
+            "revision_requested": True,
+            "feedback_comment": "Please add more detailed logging.",
+            "plan_content": "## Plan\n\nOld Plan content.",
+        }
+
+        mock_jira = _make_mock_jira()
+        runner = _make_mock_runner_success("## Plan\n\nNew Plan content with logging.")
+
+        with (
+            patch("forge.workflow.nodes.task_takeover_planning.JiraClient", return_value=mock_jira),
+            patch(
+                "forge.workflow.nodes.task_takeover_planning.ContainerRunner", return_value=runner
+            ),
+            patch("forge.workflow.nodes.task_takeover_planning.GitOperations") as mock_git,
+        ):
+            mock_git_instance = MagicMock()
+            mock_git_instance.clone = MagicMock()
+            mock_git.return_value = mock_git_instance
+            result = await generate_plan(state)
+
+        assert result["plan_content"] == "## Plan\n\nNew Plan content with logging."
+ assert result["revision_requested"] is False + assert result["feedback_comment"] is None + assert result["current_node"] == "task_plan_approval_gate" + + +class TestPlanApprovalGate: + """Tests for plan_approval_gate node.""" + + def test_plan_approval_gate_pauses(self, base_task_state: TaskTakeoverState) -> None: + """Verify plan_approval_gate pauses the state.""" + result = plan_approval_gate(base_task_state) + assert result["is_paused"] is True + assert result["current_node"] == "plan_approval_gate" + + +class TestRoutePlanApproval: + """Tests for route_plan_approval function.""" + + def test_route_plan_approval_paused(self, base_task_state: TaskTakeoverState) -> None: + """Verify it returns END when state is paused.""" + state = {**base_task_state, "is_paused": True} + assert route_plan_approval(state) == END + + def test_route_plan_approval_revision_requested( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify it returns generate_plan when revision is requested and is_paused is False.""" + state = {**base_task_state, "is_paused": False, "revision_requested": True} + assert route_plan_approval(state) == "generate_plan" + + def test_route_plan_approval_approved(self, base_task_state: TaskTakeoverState) -> None: + """Verify it returns END when plan is approved (no other flags).""" + state = {**base_task_state, "is_paused": False} + assert route_plan_approval(state) == END diff --git a/tests/unit/workflow/nodes/test_task_takeover_pr.py b/tests/unit/workflow/nodes/test_task_takeover_pr.py new file mode 100644 index 00000000..6866759a --- /dev/null +++ b/tests/unit/workflow/nodes/test_task_takeover_pr.py @@ -0,0 +1,236 @@ +"""Unit tests for task takeover PR creation node.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import TicketType +from forge.workflow.nodes.task_takeover_pr import cleanup_podman_containers, create_task_takeover_pr + + +def _make_state( + ticket_key="TASK-123", + 
ticket_type=TicketType.TASK, + workspace_path="/tmp/ws", + current_repo="acme/backend", + implemented_tasks=None, +): + return { + "ticket_key": ticket_key, + "ticket_type": ticket_type, + "current_node": "create_task_takeover_pr", + "is_paused": False, + "retry_count": 0, + "last_error": None, + "workspace_path": workspace_path, + "current_repo": current_repo, + "implemented_tasks": implemented_tasks or [], + "context": {"branch_name": "forge/TASK-123", "guardrails": ""}, + } + + +def _make_mock_jira(): + jira = AsyncMock() + issue = MagicMock() + issue.summary = "Implement user authentication" + issue.description = "Details of the authentication task" + jira.get_issue = AsyncMock(return_value=issue) + jira.add_comment = AsyncMock() + jira.transition_issue = AsyncMock() + jira.close = AsyncMock() + return jira + + +def _make_mock_github(): + github = AsyncMock() + github.get_or_create_fork = AsyncMock( + return_value={ + "owner": {"login": "fork-owner"}, + "name": "backend", + } + ) + github.sync_fork_with_upstream = AsyncMock() + github.create_pull_request = AsyncMock( + return_value={ + "html_url": "https://github.com/acme/backend/pull/42", + "number": 42, + } + ) + github.close = AsyncMock() + return github + + +def _make_mock_git(): + git = MagicMock() + git.add_fork_remote = MagicMock() + git.push_to_fork = MagicMock() + return git + + +class TestTaskTakeoverPRNode: + """Tests for create_task_takeover_pr node in Task Takeover workflow.""" + + @pytest.mark.asyncio + @patch("forge.workflow.nodes.task_takeover_pr.teardown_workspace") + @patch("forge.workflow.nodes.task_takeover_pr.cleanup_podman_containers") + async def test_successful_pr_creation(self, mock_cleanup, mock_teardown) -> None: + """Test successful PR creation, commenting, transition and teardown.""" + state = _make_state() + mock_jira = _make_mock_jira() + mock_github = _make_mock_github() + mock_git = _make_mock_git() + + # We want teardown_workspace to simulate setting workspace_path to None and 
updating the state + async def fake_teardown(s): + return {**s, "workspace_path": None, "current_node": "workspace_complete"} + + mock_teardown.side_effect = fake_teardown + + with ( + patch("forge.workflow.nodes.task_takeover_pr.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_pr.GitHubClient", return_value=mock_github), + patch("forge.workflow.nodes.task_takeover_pr.GitOperations", return_value=mock_git), + ): + result_state = await create_task_takeover_pr(state) + + # Assert fork integration and push + mock_github.get_or_create_fork.assert_called_once_with("acme", "backend") + mock_github.sync_fork_with_upstream.assert_called_once_with("fork-owner", "backend") + mock_git.add_fork_remote.assert_called_once_with("fork-owner", "backend") + mock_git.push_to_fork.assert_called_once() + + # Assert PR creation + mock_github.create_pull_request.assert_called_once_with( + owner="acme", + repo="backend", + title="[TASK-123] Implement user authentication", + body="This Pull Request implements task takeover for ticket **[TASK-123]**.\n\n### Ticket Description\nDetails of the authentication task\n\nCo-authored-by: Forge ", + head="fork-owner:forge/TASK-123", + base="main", + ) + + # Assert Jira comment and transition + mock_jira.add_comment.assert_called_once() + comment_arg = mock_jira.add_comment.call_args[0][1] + assert "[PR #42]" in comment_arg + assert "https://github.com/acme/backend/pull/42" in comment_arg + + mock_jira.transition_issue.assert_called_once_with("TASK-123", "In Review") + + # Assert cleanup/teardown + mock_cleanup.assert_called_once_with("TASK-123") + mock_teardown.assert_called_once() + + # Assert resulting state + assert result_state["workspace_path"] is None + assert result_state["current_pr_url"] == "https://github.com/acme/backend/pull/42" + assert result_state["current_pr_number"] == 42 + assert result_state["fork_owner"] == "fork-owner" + assert result_state["fork_repo"] == "backend" + assert 
"https://github.com/acme/backend/pull/42" in result_state["pr_urls"] + + @pytest.mark.asyncio + @patch("forge.workflow.nodes.task_takeover_pr.teardown_workspace") + @patch("forge.workflow.nodes.task_takeover_pr.cleanup_podman_containers") + async def test_pr_creation_missing_workspace(self, mock_cleanup, mock_teardown) -> None: + """Test PR creation node fails gracefully when workspace_path is not set.""" + state = _make_state(workspace_path=None) + mock_jira = _make_mock_jira() + + with patch("forge.workflow.nodes.task_takeover_pr.JiraClient", return_value=mock_jira): + result_state = await create_task_takeover_pr(state) + + assert "Workspace not set up" in result_state["last_error"] + assert result_state["current_node"] == "create_task_takeover_pr" + mock_cleanup.assert_not_called() + mock_teardown.assert_not_called() + + @pytest.mark.asyncio + @patch("forge.workflow.nodes.task_takeover_pr.teardown_workspace") + @patch("forge.workflow.nodes.task_takeover_pr.cleanup_podman_containers") + async def test_pr_creation_unrecognized_repo_format(self, mock_cleanup, mock_teardown) -> None: + """Test PR creation node fails gracefully when current_repo format is invalid.""" + state = _make_state(current_repo="invalid-format") + mock_jira = _make_mock_jira() + + with patch("forge.workflow.nodes.task_takeover_pr.JiraClient", return_value=mock_jira): + result_state = await create_task_takeover_pr(state) + + assert "Invalid repository format" in result_state["last_error"] + assert result_state["current_node"] == "create_task_takeover_pr" + mock_cleanup.assert_not_called() + mock_teardown.assert_not_called() + + @pytest.mark.asyncio + @patch("forge.workflow.nodes.task_takeover_pr.teardown_workspace") + @patch("forge.workflow.nodes.task_takeover_pr.cleanup_podman_containers") + async def test_pr_creation_api_failure(self, mock_cleanup, mock_teardown) -> None: + """Test node handles API errors gracefully, recording error and incrementing retry count.""" + state = _make_state() + 
mock_jira = _make_mock_jira() + mock_github = _make_mock_github() + mock_github.get_or_create_fork = AsyncMock(side_effect=Exception("GitHub API down")) + mock_git = _make_mock_git() + + with ( + patch("forge.workflow.nodes.task_takeover_pr.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_pr.GitHubClient", return_value=mock_github), + patch("forge.workflow.nodes.task_takeover_pr.GitOperations", return_value=mock_git), + ): + result_state = await create_task_takeover_pr(state) + + assert "GitHub API down" in result_state["last_error"] + assert result_state["current_node"] == "create_task_takeover_pr" + assert result_state["retry_count"] == 1 + mock_cleanup.assert_not_called() + mock_teardown.assert_not_called() + + @pytest.mark.asyncio + @patch("asyncio.create_subprocess_exec") + async def test_cleanup_podman_containers(self, mock_create_proc) -> None: + """Test cleanup_podman_containers stops and removes matched containers.""" + mock_ps_proc = AsyncMock() + mock_ps_proc.communicate = AsyncMock(return_value=(b"forge-TASK-123-abc\n", b"")) + + mock_stop_proc = AsyncMock() + mock_stop_proc.wait = AsyncMock() + + mock_rm_proc = AsyncMock() + mock_rm_proc.wait = AsyncMock() + + def side_effect(*args, **_kwargs): + if args[1] == "ps": + return mock_ps_proc + elif args[1] == "stop": + return mock_stop_proc + elif args[1] == "rm": + return mock_rm_proc + return AsyncMock() + + mock_create_proc.side_effect = side_effect + + await cleanup_podman_containers("TASK-123") + + # Verify podman commands are executed + assert mock_create_proc.call_count >= 3 + + # Verify first call is to ps + first_call_args = mock_create_proc.call_args_list[0][0] + assert first_call_args[0] == "podman" + assert first_call_args[1] == "ps" + assert "name=forge-TASK-123-" in first_call_args + + # Verify stop and rm are called + stop_called = False + rm_called = False + for call in mock_create_proc.call_args_list: + args = call[0] + if "stop" in args: + stop_called = 
True + assert "forge-TASK-123-abc" in args + if "rm" in args: + rm_called = True + assert "forge-TASK-123-abc" in args + + assert stop_called is True + assert rm_called is True diff --git a/tests/unit/workflow/nodes/test_task_takeover_review.py b/tests/unit/workflow/nodes/test_task_takeover_review.py new file mode 100644 index 00000000..ededb984 --- /dev/null +++ b/tests/unit/workflow/nodes/test_task_takeover_review.py @@ -0,0 +1,181 @@ +"""Unit tests for the qualitative review node in Task Takeover workflow.""" + +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.nodes.task_takeover_review import ( + _extract_acceptance_criteria, + _parse_qualitative_review, + run_qualitative_review, +) +from forge.workflow.task_takeover.state import ( + TaskTakeoverState, + create_initial_task_takeover_state, +) + + +def make_task_state(**overrides: Any) -> TaskTakeoverState: + """Create a TaskTakeoverState dict for review tests.""" + state = create_initial_task_takeover_state("TASK-101") + state_dict = cast(dict[str, Any], state) + state_dict.update(overrides) + return cast(TaskTakeoverState, state_dict) + + +@pytest.fixture +def base_task_state() -> TaskTakeoverState: + return make_task_state( + workspace_path="/tmp/fake-workspace-review", + current_repo="owner/repo", + context={"branch_name": "task/TASK-101"}, + ) + + +def _make_mock_jira(description: str = "Acceptance Criteria:\n- Foo\n- Bar") -> AsyncMock: + jira = AsyncMock() + issue = MagicMock() + issue.summary = "Fix session timeout" + issue.description = description + issue.project_key = "TASK" + jira.get_issue = AsyncMock(return_value=issue) + jira.add_comment = AsyncMock() + jira.close = AsyncMock() + return jira + + +class TestExtractAcceptanceCriteria: + """Tests for _extract_acceptance_criteria.""" + + def test_extract_found(self) -> None: + desc = "Some setup info.\nAcceptance Criteria:\n1. Must run fast.\n2. Must pass." 
+ criteria = _extract_acceptance_criteria(desc) + assert criteria.startswith("Acceptance Criteria:") + assert "Must pass." in criteria + + def test_extract_not_found(self) -> None: + desc = "Plain description without the heading." + criteria = _extract_acceptance_criteria(desc) + assert criteria == desc + + def test_extract_empty(self) -> None: + assert _extract_acceptance_criteria("") == "No description or acceptance criteria provided." + + +class TestParseQualitativeReview: + """Tests for _parse_qualitative_review.""" + + def test_parse_adequate(self) -> None: + output = "verdict: adequate\nfeedback: All is well!" + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "adequate" + assert feedback == "All is well!" + + def test_parse_tests_incomplete(self) -> None: + output = "verdict: tests_incomplete\nfeedback: Please add more tests." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "tests_incomplete" + assert feedback == "Please add more tests." + + def test_parse_invalid_defaults_to_incomplete(self) -> None: + output = "verdict: perfect\nfeedback: Outstanding." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "tests_incomplete" + + +class TestRunQualitativeReview: + """Tests for run_qualitative_review node.""" + + @pytest.mark.asyncio + async def test_run_qualitative_review_success(self, base_task_state: TaskTakeoverState) -> None: + mock_jira = _make_mock_jira() + mock_agent = AsyncMock() + mock_agent.run_task = AsyncMock( + return_value="verdict: adequate\nfeedback: Brilliant changes." 
+ ) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = "diff contents" + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "adequate" + assert result["review_feedback"] == "Brilliant changes." + assert result["qualitative_review_retry_count"] == 0 + assert result["qualitative_review_failed"] is False + assert result["current_node"] == "qualitative_review" + assert result["last_error"] is None + + # Verify read-only agent was invoked + mock_agent.run_task.assert_called_once() + _, kwargs = mock_agent.run_task.call_args + assert kwargs["include_tools"] is False + + @pytest.mark.asyncio + async def test_run_qualitative_review_tests_incomplete( + self, base_task_state: TaskTakeoverState + ) -> None: + mock_jira = _make_mock_jira() + mock_agent = AsyncMock() + mock_agent.run_task = AsyncMock( + return_value="verdict: tests_incomplete\nfeedback: Write more unit tests." 
+ ) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = "diff contents" + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "tests_incomplete" + assert result["review_feedback"] == "Write more unit tests." + assert result["qualitative_review_retry_count"] == 1 + assert result["qualitative_review_failed"] is True + assert result["current_node"] == "qualitative_review" + assert result["last_error"] is None + + @pytest.mark.asyncio + async def test_run_qualitative_review_missing_workspace( + self, base_task_state: TaskTakeoverState + ) -> None: + base_task_state["workspace_path"] = None + + result = await run_qualitative_review(base_task_state) + assert result["last_error"] == "Workspace not set up" + assert result["current_node"] == "qualitative_review" + + @pytest.mark.asyncio + async def test_run_qualitative_review_exception_handling( + self, base_task_state: TaskTakeoverState + ) -> None: + mock_jira = _make_mock_jira() + mock_jira.get_issue = AsyncMock(side_effect=RuntimeError("Jira connection failure")) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.error_handler.notify_error") as mock_notify, + ): + result = await run_qualitative_review(base_task_state) + + assert result["last_error"] is not None + assert "Jira connection failure" in result["last_error"] + assert result["current_node"] == "qualitative_review" + 
mock_notify.assert_called_once() diff --git a/tests/unit/workflow/nodes/test_task_takeover_triage.py b/tests/unit/workflow/nodes/test_task_takeover_triage.py new file mode 100644 index 00000000..118e4e30 --- /dev/null +++ b/tests/unit/workflow/nodes/test_task_takeover_triage.py @@ -0,0 +1,246 @@ +"""Unit tests for triage_task node.""" + +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import ForgeLabel +from forge.workflow.task_takeover.state import ( + TaskTakeoverState, + create_initial_task_takeover_state, +) + + +def make_task_state(**overrides: Any) -> TaskTakeoverState: + """Create a TaskTakeoverState dict for triage tests.""" + state = create_initial_task_takeover_state("TASK-001") + state_dict = cast(dict[str, Any], state) + state_dict.update(overrides) + return cast(TaskTakeoverState, state_dict) + + +@pytest.fixture +def complete_ticket_state() -> TaskTakeoverState: + """TaskTakeoverState with a well-specified ticket.""" + return make_task_state( + current_node="start", + ) + + +@pytest.fixture +def resume_ticket_state() -> TaskTakeoverState: + """TaskTakeoverState resuming from triage_gate.""" + return make_task_state( + current_node="triage_gate", + is_paused=True, + ) + + +@pytest.fixture +def mock_jira() -> MagicMock: + jira = MagicMock() + jira.get_issue = AsyncMock( + return_value=MagicMock( + summary="Login fails with special characters", + description="Problem Statement: ... Proposed Solution/Approach: ... 
Acceptance Criteria: ...", + ) + ) + jira.get_comments = AsyncMock(return_value=[]) + jira.add_comment = AsyncMock() + jira.set_workflow_label = AsyncMock() + jira.close = AsyncMock() + return jira + + +@pytest.fixture +def mock_agent_sufficient() -> MagicMock: + """ForgeAgent that returns 'sufficient' for the triage prompt.""" + agent = MagicMock() + agent.run_task = AsyncMock(return_value="sufficient") + agent.close = AsyncMock() + return agent + + +@pytest.fixture +def mock_agent_missing_fields() -> MagicMock: + """ForgeAgent that returns a JSON list of missing fields.""" + agent = MagicMock() + agent.run_task = AsyncMock( + return_value='["Problem Statement", "Acceptance Criteria"]' + ) + agent.close = AsyncMock() + return agent + + +class TestTriageTaskSufficientTicket: + """When the ticket has all required fields, triage passes.""" + + @pytest.mark.asyncio + async def test_sets_triage_passed_true( + self, + complete_ticket_state: TaskTakeoverState, + mock_jira: MagicMock, + mock_agent_sufficient: MagicMock, + ) -> None: + """triage_passed=True and transitions to generate_plan on success.""" + from forge.workflow.nodes.task_takeover_triage import triage_task + + with ( + patch( + "forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira + ), + patch( + "forge.workflow.nodes.task_takeover_triage.ForgeAgent", + return_value=mock_agent_sufficient, + ), + ): + result = await triage_task(complete_ticket_state) + + assert result["triage_passed"] is True + assert result["current_node"] == "generate_plan" + assert result["is_paused"] is False + assert result["triage_missing_fields"] == [] + + @pytest.mark.asyncio + async def test_acknowledgement_comment_posted_first( + self, + complete_ticket_state: TaskTakeoverState, + mock_jira: MagicMock, + mock_agent_sufficient: MagicMock, + ) -> None: + """Acknowledgement comment is posted before triage evaluation on first invocation.""" + from forge.workflow.nodes.task_takeover_triage import triage_task + + 
call_order: list[str] = [] + + async def mock_comment(*_args: Any, **_kwargs: Any) -> MagicMock: + call_order.append("comment") + return MagicMock() + + async def mock_run_task(*_args: Any, **_kwargs: Any) -> str: + call_order.append("agent") + return "sufficient" + + mock_jira.add_comment.side_effect = mock_comment + mock_agent_sufficient.run_task.side_effect = mock_run_task + + with ( + patch( + "forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira + ), + patch( + "forge.workflow.nodes.task_takeover_triage.ForgeAgent", + return_value=mock_agent_sufficient, + ), + ): + await triage_task(complete_ticket_state) + + assert call_order[0] == "comment" + assert mock_jira.add_comment.call_count == 2 # Ack comment + Success comment + + @pytest.mark.asyncio + async def test_acknowledgement_comment_suppressed_on_resume( + self, + resume_ticket_state: TaskTakeoverState, + mock_jira: MagicMock, + mock_agent_sufficient: MagicMock, + ) -> None: + """Acknowledgement comment is bypassed when resuming from triage_gate.""" + from forge.workflow.nodes.task_takeover_triage import triage_task + + with ( + patch( + "forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira + ), + patch( + "forge.workflow.nodes.task_takeover_triage.ForgeAgent", + return_value=mock_agent_sufficient, + ), + ): + await triage_task(resume_ticket_state) + + # Only the pass comment should be posted on resume + assert mock_jira.add_comment.call_count == 1 + comment_text = mock_jira.add_comment.call_args_list[0].args[1] + assert "Thanks for the update" in comment_text + + +class TestTriageTaskMissingFields: + """When the ticket is missing required fields, triage pauses.""" + + @pytest.mark.asyncio + async def test_sets_triage_passed_false( + self, + complete_ticket_state: TaskTakeoverState, + mock_jira: MagicMock, + mock_agent_missing_fields: MagicMock, + ) -> None: + """triage_passed=False, is_paused=True, and transitions to triage_gate on failure.""" + from 
forge.workflow.nodes.task_takeover_triage import triage_task + + with ( + patch( + "forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira + ), + patch( + "forge.workflow.nodes.task_takeover_triage.ForgeAgent", + return_value=mock_agent_missing_fields, + ), + ): + result = await triage_task(complete_ticket_state) + + assert result["triage_passed"] is False + assert result["current_node"] == "triage_gate" + assert result["is_paused"] is True + assert "Problem Statement" in result["triage_missing_fields"] + assert "Acceptance Criteria" in result["triage_missing_fields"] + + @pytest.mark.asyncio + async def test_applies_triage_pending_label_and_posts_comment( + self, + complete_ticket_state: TaskTakeoverState, + mock_jira: MagicMock, + mock_agent_missing_fields: MagicMock, + ) -> None: + """Applies forge:task-triage-pending label and posts a detailed comment on failure.""" + from forge.workflow.nodes.task_takeover_triage import triage_task + + with ( + patch( + "forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira + ), + patch( + "forge.workflow.nodes.task_takeover_triage.ForgeAgent", + return_value=mock_agent_missing_fields, + ), + ): + await triage_task(complete_ticket_state) + + mock_jira.set_workflow_label.assert_called_once_with( + "TASK-001", ForgeLabel.TASK_TRIAGE_PENDING + ) + assert mock_jira.add_comment.call_count == 2 # Ack comment + Missing fields comment + missing_fields_comment = mock_jira.add_comment.call_args_list[1].args[1] + assert "Problem Statement" in missing_fields_comment + assert "Acceptance Criteria" in missing_fields_comment + + +class TestTriageTaskErrorHandling: + """Error handling and retry logic.""" + + @pytest.mark.asyncio + async def test_escalates_to_blocked_on_max_retries(self, mock_jira: MagicMock) -> None: + """Transitions to escalate_blocked when max retries exceeded.""" + from forge.workflow.nodes.task_takeover_triage import triage_task + + state = make_task_state(retry_count=3) + 
with ( + patch( + "forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira + ), + ): + result = await triage_task(state) + + assert result["current_node"] == "escalate_blocked" + assert result["is_paused"] is False diff --git a/tests/unit/workflow/task_takeover/test_graph.py b/tests/unit/workflow/task_takeover/test_graph.py new file mode 100644 index 00000000..0f33d467 --- /dev/null +++ b/tests/unit/workflow/task_takeover/test_graph.py @@ -0,0 +1,118 @@ +"""Unit tests for Task Takeover workflow state and graph structure.""" + +from typing import Any, cast +import pytest +from langgraph.graph import END, StateGraph + +from forge.models.workflow import TicketType +from forge.workflow.task_takeover.graph import ( + _route_after_triage_check, + build_task_takeover_graph, + route_entry, +) +from forge.workflow.task_takeover.state import ( + TaskTakeoverState, + create_initial_task_takeover_state, +) + + +def _task_state(**overrides: Any) -> TaskTakeoverState: + base = { + "ticket_key": "TASK-1", + "ticket_type": TicketType.TASK, + "current_node": "start", + "is_paused": False, + "retry_count": 0, + "last_error": None, + "triage_passed": False, + "triage_missing_fields": [], + "plan_content": None, + } + return cast(TaskTakeoverState, {**base, **overrides}) + + +class TestTaskTakeoverState: + """Test TaskTakeoverState definition and initial state creation.""" + + def test_state_fields(self) -> None: + """Verify TaskTakeoverState has required fields.""" + # Simple instantiation to verify type definition + state = TaskTakeoverState( + triage_passed=True, + triage_missing_fields=["steps"], + plan_content="Takeover plan", + ) + assert state["triage_passed"] is True + assert state["triage_missing_fields"] == ["steps"] + assert state["plan_content"] == "Takeover plan" + + def test_create_initial_state_defaults(self) -> None: + """create_initial_task_takeover_state sets default values appropriately.""" + state = create_initial_task_takeover_state("TASK-1") 
+ assert state["ticket_key"] == "TASK-1" + assert state["ticket_type"] == TicketType.TASK + assert state["triage_passed"] is False + assert state["triage_missing_fields"] == [] + assert state["plan_content"] is None + assert state["current_node"] == "start" + + +class TestRouteEntry: + """route_entry maps current_node values to correct resume targets.""" + + @pytest.mark.parametrize( + "node,expected", + [ + ("triage_check", "triage_check"), + ("triage_gate", "triage_gate"), + ("generate_plan", "generate_plan"), + ("task_plan_approval_gate", "task_plan_approval_gate"), + ("escalate_blocked", "escalate_blocked"), + ("complete", END), + ], + ) + def test_route_entry_mapping(self, node: str, expected: str) -> None: + """route_entry maps each current_node to the correct resume target.""" + state = _task_state(current_node=node) + assert route_entry(state) == expected + + def test_new_task_routes_to_triage(self) -> None: + """A fresh task takeover ticket, whose current_node is still 'start', starts at triage_check.""" + state = create_initial_task_takeover_state(ticket_key="TASK-1") + assert route_entry(state) == "triage_check" + + def test_unknown_node_routes_to_triage(self) -> None: + """An unrecognized current_node value restarts from triage_check.""" + state = _task_state(current_node="unrecognized_node") + assert route_entry(state) == "triage_check" + + +class TestTriageCheckRouting: + """_route_after_triage_check transitions correctly.""" + + @pytest.mark.parametrize( + "current_node,expected", + [ + ("analyze_bug", "generate_plan"), + ("triage_gate", "triage_gate"), + ("escalate_blocked", "escalate_blocked"), + ("unknown_node", "triage_gate"), + ], + ) + def test_route_after_triage_check(self, current_node: str, expected: str) -> None: + """_route_after_triage_check maps triage results to task takeover nodes.""" + state = _task_state(current_node=current_node) + assert _route_after_triage_check(state) == expected + + +class TestTaskTakeoverGraph: + """Test StateGraph compilation
and logic.""" + + def test_build_task_takeover_graph(self) -> None: + """build_task_takeover_graph returns a compiled StateGraph.""" + graph = build_task_takeover_graph() + assert isinstance(graph, StateGraph) + + # Compile the graph to verify correctness + compiled_graph = graph.compile() + assert compiled_graph is not None diff --git a/tests/unit/workflow/task_takeover/test_workflow.py b/tests/unit/workflow/task_takeover/test_workflow.py new file mode 100644 index 00000000..3a22ce8d --- /dev/null +++ b/tests/unit/workflow/task_takeover/test_workflow.py @@ -0,0 +1,96 @@ +"""Tests for TaskTakeoverWorkflow.""" + +from unittest.mock import patch + +import pytest +from langgraph.graph import StateGraph + +from forge.models.workflow import TicketType +from forge.workflow.task_takeover import TaskTakeoverWorkflow +from forge.workflow.task_takeover.state import TaskTakeoverState + + +class TestTaskTakeoverWorkflow: + """Tests for TaskTakeoverWorkflow class.""" + + @pytest.fixture(autouse=True) + def mock_settings(self): + """Mock settings to enable task takeover.""" + from forge.config import Settings, TaskTakeoverSettings + + mock_s = Settings() + mock_s.task_takeover = TaskTakeoverSettings(enabled=True) + + with patch("forge.config.get_settings", return_value=mock_s): + yield + + def test_workflow_has_name(self): + """TaskTakeoverWorkflow has name attribute.""" + workflow = TaskTakeoverWorkflow() + assert workflow.name == "task_takeover" + + def test_workflow_has_description(self): + """TaskTakeoverWorkflow has description.""" + workflow = TaskTakeoverWorkflow() + assert workflow.description == "Task Takeover workflow" + + def test_state_schema_returns_task_takeover_state(self): + """state_schema returns TaskTakeoverState.""" + workflow = TaskTakeoverWorkflow() + assert workflow.state_schema is TaskTakeoverState + + def test_build_graph_returns_state_graph(self): + """build_graph returns a StateGraph.""" + workflow = TaskTakeoverWorkflow() + graph = 
workflow.build_graph() + assert isinstance(graph, StateGraph) + + def test_create_initial_state(self): + """create_initial_state returns TaskTakeoverState with defaults.""" + workflow = TaskTakeoverWorkflow() + state = workflow.create_initial_state("TASK-123") + + assert state["ticket_key"] == "TASK-123" + assert state["ticket_type"] == TicketType.TASK + assert state["current_node"] == "start" + + def test_matches_strictly_when_both_managed_and_trigger_present(self): + """matches returns True when forge:managed and exact trigger are present.""" + workflow = TaskTakeoverWorkflow() + + # Exact trigger "forge:task-takeover" + assert ( + workflow.matches(TicketType.TASK, ["forge:managed", "forge:task-takeover"], {}) is True + ) + + # Exact trigger "forge:managed:task" + assert ( + workflow.matches(TicketType.TASK, ["forge:managed", "forge:managed:task"], {}) is True + ) + + # Exact trigger "forge:managed:task-takeover" + assert ( + workflow.matches(TicketType.TASK, ["forge:managed", "forge:managed:task-takeover"], {}) + is True + ) + + def test_matches_returns_false_when_only_managed_present(self): + """matches returns False when only forge:managed is present without trigger.""" + workflow = TaskTakeoverWorkflow() + assert workflow.matches(TicketType.TASK, ["forge:managed"], {}) is False + assert ( + workflow.matches(TicketType.TASK, ["forge:managed", "forge:prd-drafting"], {}) is False + ) + + def test_matches_returns_true_when_only_trigger_present_without_managed(self): + """matches returns True when trigger label is present even if forge:managed is missing.""" + workflow = TaskTakeoverWorkflow() + assert workflow.matches(TicketType.TASK, ["forge:task-takeover"], {}) is True + assert workflow.matches(TicketType.TASK, ["forge:managed:task"], {}) is True + assert workflow.matches(TicketType.TASK, ["forge:managed:task-takeover"], {}) is True + + def test_matches_returns_false_with_non_trigger_labels(self): + """matches returns False if no exact trigger label is 
present.""" + workflow = TaskTakeoverWorkflow() + assert workflow.matches(TicketType.TASK, ["forge:managed-something"], {}) is False + assert workflow.matches(TicketType.TASK, ["other-label"], {}) is False diff --git a/tests/unit/workflow/test_registry.py b/tests/unit/workflow/test_registry.py index 5c7ba5a7..808e537e 100644 --- a/tests/unit/workflow/test_registry.py +++ b/tests/unit/workflow/test_registry.py @@ -1,5 +1,8 @@ """Tests for workflow registry.""" +from unittest.mock import patch + +import pytest from forge.models.workflow import TicketType @@ -7,6 +10,17 @@ class TestDefaultRouter: """Tests for create_default_router.""" + @pytest.fixture(autouse=True) + def mock_settings(self): + """Mock settings to enable task takeover.""" + from forge.config import Settings, TaskTakeoverSettings + + mock_s = Settings() + mock_s.task_takeover = TaskTakeoverSettings(enabled=True) + + with patch("forge.config.get_settings", return_value=mock_s): + yield + def test_creates_router_with_workflows(self): """create_default_router returns router with workflows.""" from forge.workflow.registry import create_default_router @@ -14,7 +28,7 @@ def test_creates_router_with_workflows(self): router = create_default_router() workflows = router.list_workflows() - assert len(workflows) >= 2 + assert len(workflows) >= 3 def test_resolves_feature_to_feature_workflow(self): """Feature tickets resolve to FeatureWorkflow.""" @@ -35,3 +49,75 @@ def test_resolves_bug_to_bug_workflow(self): assert workflow is not None assert workflow.name == "bug" + + def test_resolves_to_task_takeover_workflow_with_trigger_labels(self): + """Tickets with forge:managed and task takeover trigger labels resolve to TaskTakeoverWorkflow.""" + from forge.workflow.registry import create_default_router + + router = create_default_router() + + # Feature ticket with task takeover triggers + workflow = router.resolve( + TicketType.FEATURE, + ["forge:managed", "forge:task-takeover"], + {}, + ) + assert workflow is not 
None + assert workflow.name == "task_takeover" + + # Bug ticket with task takeover triggers + workflow = router.resolve( + TicketType.BUG, + ["forge:managed", "forge:managed:task-takeover"], + {}, + ) + assert workflow is not None + assert workflow.name == "task_takeover" + + # Standalone task ticket with takeover triggers + workflow = router.resolve( + TicketType.TASK, + ["forge:managed", "forge:managed:task"], + {}, + ) + assert workflow is not None + assert workflow.name == "task_takeover" + + def test_resolves_to_general_workflow_without_trigger_labels(self): + """Tickets with forge:managed but without task takeover trigger labels resolve to general workflows.""" + from forge.workflow.registry import create_default_router + + router = create_default_router() + + # Feature ticket without task takeover triggers + workflow = router.resolve( + TicketType.FEATURE, + ["forge:managed"], + {}, + ) + assert workflow is not None + assert workflow.name == "feature" + + # Bug ticket without task takeover triggers + workflow = router.resolve( + TicketType.BUG, + ["forge:managed"], + {}, + ) + assert workflow is not None + assert workflow.name == "bug" + + def test_task_takeover_has_priority_over_bug_workflow(self): + """Conflicting labels (e.g. 
both forge:managed and forge:managed:task on a Bug ticket) prioritize Task Takeover routing.""" + from forge.workflow.registry import create_default_router + + router = create_default_router() + + # A Bug ticket with both forge:managed and forge:managed:task should resolve to TaskTakeoverWorkflow, not BugWorkflow + workflow = router.resolve( + TicketType.BUG, + ["forge:managed", "forge:managed:task"], + {}, + ) + assert workflow is not None + assert workflow.name == "task_takeover" diff --git a/tests/unit/workflow/test_router.py b/tests/unit/workflow/test_router.py index 6fedaac4..a076d70b 100644 --- a/tests/unit/workflow/test_router.py +++ b/tests/unit/workflow/test_router.py @@ -1,5 +1,8 @@ """Tests for WorkflowRouter.""" +from unittest.mock import patch + +import pytest from langgraph.graph import StateGraph from forge.models.workflow import TicketType @@ -16,9 +19,7 @@ class MockWorkflow(BaseWorkflow): def state_schema(self) -> type: return BaseState - def matches( - self, ticket_type: TicketType, _labels: list[str], _event: dict - ) -> bool: + def matches(self, ticket_type: TicketType, _labels: list[str], _event: dict) -> bool: return ticket_type == TicketType.FEATURE def build_graph(self) -> StateGraph: @@ -38,9 +39,7 @@ class MockBugWorkflow(BaseWorkflow): def state_schema(self) -> type: return BaseState - def matches( - self, ticket_type: TicketType, _labels: list[str], _event: dict - ) -> bool: + def matches(self, ticket_type: TicketType, _labels: list[str], _event: dict) -> bool: return ticket_type == TicketType.BUG def build_graph(self) -> StateGraph: @@ -53,6 +52,17 @@ def build_graph(self) -> StateGraph: class TestWorkflowRouter: """Tests for WorkflowRouter.""" + @pytest.fixture(autouse=True) + def mock_settings(self): + """Mock settings to enable task takeover.""" + from forge.config import Settings, TaskTakeoverSettings + + mock_s = Settings() + mock_s.task_takeover = TaskTakeoverSettings(enabled=True) + + with patch("forge.config.get_settings", 
return_value=mock_s): + yield + def test_register_workflow(self): """Can register a workflow class.""" from forge.workflow.router import WorkflowRouter @@ -126,3 +136,35 @@ def test_list_workflows(self): assert len(workflows) == 2 assert workflows[0]["name"] == "mock" assert workflows[1]["name"] == "mock_bug" + + def test_resolve_exact_matching_no_accidental_prefix_triggers(self): + """Verify that prefix-based triggers do not resolve to TaskTakeoverWorkflow.""" + from forge.workflow.router import WorkflowRouter + from forge.workflow.task_takeover import TaskTakeoverWorkflow + + router = WorkflowRouter() + router.register(TaskTakeoverWorkflow) + + # Labels that start with a trigger but are not exact matches should not resolve + prefix_labels_cases = [ + ["forge:managed", "forge:task-takeover-fake"], + ["forge:managed", "forge:managed:task-fake"], + ["forge:managed", "forge:managed:task-takeover-fake"], + ] + + for labels in prefix_labels_cases: + workflow = router.resolve( + ticket_type=TicketType.BUG, + labels=labels, + event={}, + ) + assert workflow is None, f"Accidentally resolved with prefix-trigger labels: {labels}" + + # An exact match still resolves correctly + workflow = router.resolve( + ticket_type=TicketType.BUG, + labels=["forge:managed", "forge:task-takeover"], + event={}, + ) + assert workflow is not None + assert workflow.name == "task_takeover" diff --git a/tests/workflow/test_qualitative_review.py b/tests/workflow/test_qualitative_review.py new file mode 100644 index 00000000..b251ba2a --- /dev/null +++ b/tests/workflow/test_qualitative_review.py @@ -0,0 +1,366 @@ +"""Unit and integration tests for Task Takeover Qualitative Review Node.""" + +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.workflow.nodes.task_takeover_review import ( + _extract_acceptance_criteria, + _parse_qualitative_review, + run_qualitative_review, +) +from forge.workflow.task_takeover.state import ( + 
TaskTakeoverState, + create_initial_task_takeover_state, +) + + +def make_task_state(**overrides: Any) -> TaskTakeoverState: + """Create a TaskTakeoverState dict for review tests.""" + state = create_initial_task_takeover_state("TASK-101") + state_dict = cast(dict[str, Any], state) + state_dict.update(overrides) + return cast(TaskTakeoverState, state_dict) + + +@pytest.fixture +def base_task_state() -> TaskTakeoverState: + return make_task_state( + workspace_path="/tmp/fake-workspace-review", + current_repo="owner/repo", + context={"branch_name": "task/TASK-101"}, + ) + + +def _make_mock_jira(description: str = "Acceptance Criteria:\n- Foo\n- Bar") -> AsyncMock: + jira = AsyncMock() + issue = MagicMock() + issue.summary = "Fix session timeout" + issue.description = description + issue.project_key = "TASK" + jira.get_issue = AsyncMock(return_value=issue) + jira.add_comment = AsyncMock() + jira.close = AsyncMock() + return jira + + +class TestParseQualitativeReview: + """Tests for _parse_qualitative_review helper.""" + + def test_parses_adequate_success(self) -> None: + output = "verdict: adequate\nfeedback: Everything is correct and fully tested." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "adequate" + assert feedback == "Everything is correct and fully tested." + + def test_parses_tests_incomplete_failure(self) -> None: + output = "verdict: tests_incomplete\nfeedback: Tests do not fail without the fix." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "tests_incomplete" + assert feedback == "Tests do not fail without the fix." + + def test_unknown_verdict_defaults_to_tests_incomplete(self) -> None: + """Unrecognized or absent verdict defaults to tests_incomplete to avoid skipping quality gate.""" + output = "verdict: outstanding\nfeedback: Great work." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "tests_incomplete" + assert feedback == "Great work." 
+ + def test_case_insensitive_verdict(self) -> None: + output = "Verdict: Adequate\nfeedback: Well done." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "adequate" + assert feedback == "Well done." + + def test_backtick_and_literal_escape_after_verdict_parses_correctly(self) -> None: + """LLM output with trailing backtick and literal \\n-dash is still parsed.""" + output = "verdict: adequate`\\n- next section\nfeedback: Good." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "adequate" + assert feedback == "Good." + + def test_verdict_in_inline_code_backticks_parses_correctly(self) -> None: + """Verdict wrapped in markdown inline code backticks is parsed.""" + output = "verdict: `adequate`\nfeedback: Excellent." + verdict, feedback = _parse_qualitative_review(output) + assert verdict == "adequate" + assert feedback == "Excellent." + + +class TestExtractAcceptanceCriteria: + """Tests for _extract_acceptance_criteria helper.""" + + def test_extract_found(self) -> None: + desc = "Some setup info.\nAcceptance Criteria:\n1. Must run fast.\n2. Must pass." + criteria = _extract_acceptance_criteria(desc) + assert criteria.startswith("Acceptance Criteria:") + assert "Must pass." in criteria + + def test_extract_not_found(self) -> None: + desc = "Plain description without the heading." + criteria = _extract_acceptance_criteria(desc) + assert criteria == desc + + def test_extract_empty(self) -> None: + assert _extract_acceptance_criteria("") == "No description or acceptance criteria provided." 
+ + +class TestRunQualitativeReviewNode: + """Tests for run_qualitative_review node.""" + + @pytest.mark.asyncio + async def test_run_qualitative_review_success_state_updates( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify state updates when qualitative review passes (verdict is adequate).""" + mock_jira = _make_mock_jira() + mock_agent = AsyncMock() + mock_agent.run_task = AsyncMock( + return_value="verdict: adequate\nfeedback: All acceptance criteria met and automated tests verified." + ) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = "diff contents" + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "adequate" + assert "All acceptance criteria met" in result["review_feedback"] + assert result["qualitative_review_retry_count"] == 0 + assert result["qualitative_review_failed"] is False + assert result["current_node"] == "qualitative_review" + assert result["last_error"] is None + + @pytest.mark.asyncio + async def test_run_qualitative_review_failure_state_updates( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify state updates and retry metric increment when review fails.""" + mock_jira = _make_mock_jira() + mock_agent = AsyncMock() + mock_agent.run_task = AsyncMock( + return_value="verdict: tests_incomplete\nfeedback: No automated tests found in the git diff." 
+ ) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = "diff contents" + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "tests_incomplete" + assert "No automated tests found" in result["review_feedback"] + assert result["qualitative_review_retry_count"] == 1 + assert result["qualitative_review_failed"] is True + assert result["current_node"] == "qualitative_review" + assert result["last_error"] is None + + @pytest.mark.asyncio + async def test_run_qualitative_review_retry_increment( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify that existing retry counts are incremented correctly on failure.""" + base_task_state["qualitative_review_retry_count"] = 1 + + mock_jira = _make_mock_jira() + mock_agent = AsyncMock() + mock_agent.run_task = AsyncMock( + return_value="verdict: tests_incomplete\nfeedback: Still lacking necessary test coverage." 
+ ) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = "diff contents" + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["qualitative_review_retry_count"] == 2 + assert result["qualitative_review_failed"] is True + + @pytest.mark.asyncio + async def test_run_qualitative_review_valid_diff( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify qualitative review behavior when dealing with a valid git diff structure. + + Valid structure has requirements met and automated tests added. + """ + mock_jira = _make_mock_jira( + description="Acceptance Criteria:\n1. Must implement user authentication.\n2. Must add tests." + ) + mock_agent = AsyncMock() + # Mocking LLM confirming that the diff met all requirements and added tests + mock_agent.run_task = AsyncMock( + return_value="verdict: adequate\nfeedback: Perfect, all requirements met and tests are written." 
+ ) + + valid_diff = """diff --git a/src/auth.py b/src/auth.py +new file mode 100644 +--- /dev/null ++++ b/src/auth.py +@@ -0,0 +1,5 @@ ++def login(): ++ return True +diff --git a/tests/test_auth.py b/tests/test_auth.py +new file mode 100644 +--- /dev/null ++++ b/tests/test_auth.py +@@ -0,0 +1,4 @@ ++from src.auth import login ++def test_login(): ++ assert login() is True +""" + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = valid_diff + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "adequate" + assert result["qualitative_review_failed"] is False + + @pytest.mark.asyncio + async def test_run_qualitative_review_invalid_diff_missing_tests( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify qualitative review behavior when dealing with an invalid git diff structure lacking tests.""" + mock_jira = _make_mock_jira( + description="Acceptance Criteria:\n1. Must implement user authentication.\n2. Must add tests." + ) + mock_agent = AsyncMock() + # Mocking LLM indicating that no automated test is found + mock_agent.run_task = AsyncMock( + return_value="verdict: tests_incomplete\nfeedback: No automated test was found in the git diff." 
+ ) + + invalid_diff = """diff --git a/src/auth.py b/src/auth.py +new file mode 100644 +--- /dev/null ++++ b/src/auth.py +@@ -0,0 +1,5 @@ ++def login(): ++ return True +""" + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = invalid_diff + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "tests_incomplete" + assert result["qualitative_review_failed"] is True + + @pytest.mark.asyncio + async def test_run_qualitative_review_invalid_diff_unmet_criteria( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify qualitative review behavior when dealing with an invalid git diff structure that fails requirements.""" + mock_jira = _make_mock_jira( + description="Acceptance Criteria:\n1. Must implement user authentication.\n2. Must add tests." + ) + mock_agent = AsyncMock() + # Mocking LLM indicating that the implementation is incomplete or buggy + mock_agent.run_task = AsyncMock( + return_value="verdict: tests_incomplete\nfeedback: The user authentication logic is missing password hashing requirement." 
+ ) + + invalid_diff = """diff --git a/src/auth.py b/src/auth.py +new file mode 100644 +--- /dev/null ++++ b/src/auth.py +@@ -0,0 +1,4 @@ ++def login(): ++ # Missing password hashing or actual implementation ++ return True +""" + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_review.GitOperations") as mock_git, + patch("forge.workflow.nodes.task_takeover_review.ForgeAgent", return_value=mock_agent), + patch("forge.workflow.nodes.task_takeover_review.post_status_comment"), + ): + mock_git_instance = MagicMock() + mock_git_instance._run_git = MagicMock() + mock_git_instance._run_git.return_value.returncode = 0 + mock_git_instance._run_git.return_value.stdout = invalid_diff + mock_git.return_value = mock_git_instance + + result = await run_qualitative_review(base_task_state) + + assert result["review_verdict"] == "tests_incomplete" + assert result["qualitative_review_failed"] is True + + @pytest.mark.asyncio + async def test_run_qualitative_review_missing_workspace( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify error state is set when the workspace path is missing.""" + base_task_state["workspace_path"] = None + + result = await run_qualitative_review(base_task_state) + assert result["last_error"] == "Workspace not set up" + assert result["current_node"] == "qualitative_review" + + @pytest.mark.asyncio + async def test_run_qualitative_review_exception_handling( + self, base_task_state: TaskTakeoverState + ) -> None: + """Verify robust error recovery and notify_error triggering when exceptions are raised.""" + mock_jira = _make_mock_jira() + mock_jira.get_issue = AsyncMock(side_effect=RuntimeError("Jira API timeout")) + + with ( + patch("forge.workflow.nodes.task_takeover_review.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.error_handler.notify_error") as mock_notify, + ): + result = await run_qualitative_review(base_task_state) + + 
assert result["last_error"] is not None + assert "Jira API timeout" in result["last_error"] + assert result["current_node"] == "qualitative_review" + mock_notify.assert_called_once() diff --git a/tests/workflow/test_task_takeover_graph.py b/tests/workflow/test_task_takeover_graph.py new file mode 100644 index 00000000..9a86a4a2 --- /dev/null +++ b/tests/workflow/test_task_takeover_graph.py @@ -0,0 +1,259 @@ +"""Unit and integration tests for Task Takeover workflow graph and routing.""" + +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langgraph.graph import END, StateGraph + +from forge.models.workflow import ForgeLabel, TicketType +from forge.workflow.gates.task_plan_approval import route_task_plan_approval +from forge.workflow.task_takeover.graph import ( + _route_after_answer, + _route_after_triage_check, + build_task_takeover_graph, + route_entry, +) +from forge.workflow.task_takeover.state import ( + TaskTakeoverState, +) + + +def make_task_state(**overrides: Any) -> TaskTakeoverState: + """Create a TaskTakeoverState dict for graph tests.""" + base = { + "ticket_key": "TASK-123", + "ticket_type": TicketType.TASK, + "current_node": "start", + "is_paused": False, + "retry_count": 0, + "last_error": None, + "triage_passed": False, + "triage_missing_fields": [], + "plan_content": None, + } + return cast(TaskTakeoverState, {**base, **overrides}) + + +class TestTaskTakeoverGraphStructure: + """Test LangGraph StateGraph structure and compilation.""" + + def test_graph_compilation_and_nodes(self) -> None: + """Verify the graph compiles and contains the correct nodes and transitions.""" + graph = build_task_takeover_graph() + assert isinstance(graph, StateGraph) + + compiled_graph = graph.compile() + assert compiled_graph is not None + + # Verify expected nodes are present in the compiled graph + expected_nodes = { + "route_entry", + "triage_check", + "triage_gate", + "generate_plan", + 
"task_plan_approval_gate", + "escalate_blocked", + "answer_question", + "setup_workspace", + "execute_task_changes", + "run_qualitative_review", + "create_task_takeover_pr", + } + for node in expected_nodes: + assert node in compiled_graph.nodes + + +class TestPathTransitions: + """Test path transitions and route entry logic for state progression.""" + + @pytest.mark.parametrize( + "current_node, expected_next", + [ + ("triage_check", "triage_check"), + ("triage_gate", "triage_gate"), + ("generate_plan", "generate_plan"), + ("task_plan_approval_gate", "task_plan_approval_gate"), + ("escalate_blocked", "escalate_blocked"), + ("setup_workspace", "setup_workspace"), + ("execute_task_changes", "execute_task_changes"), + ("qualitative_review", "run_qualitative_review"), + ("create_task_takeover_pr", "create_task_takeover_pr"), + ("complete", END), + ("", "triage_check"), + ("unknown_node", "triage_check"), + ], + ) + def test_route_entry(self, current_node: str, expected_next: str) -> None: + """Verify that route_entry resumes at the appropriate node or restarts from triage.""" + state = make_task_state(current_node=current_node) + assert route_entry(state) == expected_next + + @pytest.mark.parametrize( + "current_node, expected_next", + [ + ("generate_plan", "generate_plan"), + ("triage_gate", "triage_gate"), + ("escalate_blocked", "escalate_blocked"), + ("unknown_node", "triage_gate"), + ], + ) + def test_route_after_triage_check(self, current_node: str, expected_next: str) -> None: + """Verify route_after_triage_check path routing.""" + state = make_task_state(current_node=current_node) + assert _route_after_triage_check(state) == expected_next + + @pytest.mark.parametrize( + "current_node, expected_next", + [ + ("task_plan_approval_gate", "task_plan_approval_gate"), + ("", "task_plan_approval_gate"), + ("some_other_gate", "some_other_gate"), + ], + ) + def test_route_after_answer(self, current_node: str, expected_next: str) -> None: + """Verify route_after_answer 
returns back to the original gate.""" + state = make_task_state(current_node=current_node) + assert _route_after_answer(state) == expected_next + + +class TestQualitativeReviewRouting: + """Test routing after run_qualitative_review.""" + + def test_route_after_qualitative_review_adequate(self) -> None: + """If review is adequate, proceed to PR creation.""" + from forge.workflow.task_takeover.graph import _route_after_qualitative_review + + state = make_task_state( + review_verdict="adequate", + qualitative_review_retry_count=0, + ) + assert _route_after_qualitative_review(state) == "create_task_takeover_pr" + + def test_route_after_qualitative_review_failed_under_limit(self) -> None: + """If review is failed or incomplete and under the limit, route back to execute_task_changes.""" + from forge.workflow.task_takeover.graph import _route_after_qualitative_review + + state = make_task_state( + review_verdict="tests_incomplete", + qualitative_review_retry_count=1, + ) + # Assuming standard review_max_attempts limit is 2, retry_count of 1 is under the limit + assert _route_after_qualitative_review(state) == "execute_task_changes" + + def test_route_after_qualitative_review_failed_at_or_above_limit(self) -> None: + """If review is failed or incomplete and at/above the limit, transition to escalate_blocked.""" + from forge.workflow.task_takeover.graph import _route_after_qualitative_review + + state = make_task_state( + review_verdict="tests_incomplete", + qualitative_review_retry_count=2, + ) + # retry_count of 2 is at/above the limit of 2, so transition to escalate_blocked + assert _route_after_qualitative_review(state) == "escalate_blocked" + + +class TestInteractiveGateBehavior: + """Test interactive gate behavior for plan approvals, questions, and revision requests.""" + + @pytest.fixture + def paused_state(self) -> TaskTakeoverState: + return make_task_state( + current_node="task_plan_approval_gate", + is_paused=True, + ) + + def 
test_gate_remains_paused_waiting_for_updates(self, paused_state: TaskTakeoverState) -> None: + """If still paused and no revision/question signals exist, stay paused (END).""" + result = route_task_plan_approval(paused_state) + assert result == END + + def test_gate_routes_to_answer_question_on_prefix( + self, paused_state: TaskTakeoverState + ) -> None: + """Comment prefixed with '?' or '@forge ask' routes to answer_question.""" + # 1. Direct bool flag + state_bool = {**paused_state, "is_question": True} + assert route_task_plan_approval(state_bool) == "answer_question" + + # 2. '?' prefix comment + state_q = {**paused_state, "feedback_comment": "?Can we run this in parallel?"} + assert route_task_plan_approval(state_q) == "answer_question" + + # 3. '@forge ask' prefix comment + state_ask = {**paused_state, "feedback_comment": "@forge ask how does this scale?"} + assert route_task_plan_approval(state_ask) == "answer_question" + + def test_gate_routes_to_regenerate_plan_on_prefix( + self, paused_state: TaskTakeoverState + ) -> None: + """Comment prefixed with '!' routes to regenerate_plan.""" + # 1. Direct bool flag + state_bool = {**paused_state, "revision_requested": True} + assert route_task_plan_approval(state_bool) == "regenerate_plan" + + # 2. '!' 
prefix comment + state_excl = {**paused_state, "feedback_comment": "!Please add redis cache."} + assert route_task_plan_approval(state_excl) == "regenerate_plan" + + def test_gate_routes_to_setup_workspace_on_label_approval( + self, paused_state: TaskTakeoverState + ) -> None: + """Changing the label to forge:task-plan-approved clears is_paused and routes to setup_workspace.""" + state_approved = {**paused_state, "is_paused": False} + assert route_task_plan_approval(state_approved) == "setup_workspace" + + def test_yolo_mode_bypasses_approval(self, paused_state: TaskTakeoverState) -> None: + """YOLO mode bypasses the approval checkpoints completely.""" + state_yolo = {**paused_state, "yolo_mode": True} + assert route_task_plan_approval(state_yolo) == "setup_workspace" + + +class TestWorkflowIdentityLabelTransitions: + """Test that workflow identity labels are preserved across transitions.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "identity_label", + ["forge:managed:task", "forge:managed:task-takeover"], + ) + async def test_identity_labels_preserved_during_transition(self, identity_label: str) -> None: + """Verify that forge:managed:task and forge:managed:task-takeover are not removed during transitions.""" + from forge.integrations.jira.client import JiraClient + + mock_client = MagicMock() + mock_response = MagicMock() + mock_client.put = AsyncMock(return_value=mock_response) + + # Initialize JiraClient and mock methods + jira = JiraClient() + jira._client = mock_client + jira.get_labels = AsyncMock( + return_value=[ + "forge:managed", + identity_label, + "forge:task-triage-pending", + ] + ) + + with patch.object(jira, "_get_client", return_value=mock_client): + await jira.set_workflow_label("TASK-123", ForgeLabel.TASK_PLAN_PENDING) + + # Retrieve the PUT request payload + mock_client.put.assert_called_once() + put_url = mock_client.put.call_args[0][0] + put_json = mock_client.put.call_args[1]["json"] + + assert put_url == "/issue/TASK-123" + + # 
Verify operations + operations = put_json["update"]["labels"] + removed_labels = [op["remove"] for op in operations if "remove" in op] + added_labels = [op["add"] for op in operations if "add" in op] + + # Verify that identity label was NOT removed + assert identity_label not in removed_labels + # Verify that the old state label was removed + assert "forge:task-triage-pending" in removed_labels + # Verify that the new plan pending label was added + assert ForgeLabel.TASK_PLAN_PENDING.value in added_labels diff --git a/tests/workflow/test_task_takeover_triage.py b/tests/workflow/test_task_takeover_triage.py new file mode 100644 index 00000000..0e68d73a --- /dev/null +++ b/tests/workflow/test_task_takeover_triage.py @@ -0,0 +1,148 @@ +"""Unit and integration tests for Task Takeover triage.""" + +import json +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from forge.models.workflow import ForgeLabel +from forge.workflow.nodes.task_takeover_triage import triage_task +from forge.workflow.task_takeover.state import ( + TaskTakeoverState, + create_initial_task_takeover_state, +) + + +def make_task_state(**overrides: Any) -> TaskTakeoverState: + """Create a TaskTakeoverState dict for triage tests.""" + state = create_initial_task_takeover_state("TASK-123") + state_dict = cast(dict[str, Any], state) + state_dict.update(overrides) + return cast(TaskTakeoverState, state_dict) + + +@pytest.fixture +def mock_jira() -> MagicMock: + jira = MagicMock() + jira.get_issue = AsyncMock( + return_value=MagicMock( + summary="Login fails with special characters", + description="Problem description", + ) + ) + jira.get_comments = AsyncMock(return_value=[]) + jira.add_comment = AsyncMock() + jira.set_workflow_label = AsyncMock() + jira.close = AsyncMock() + return jira + + +@pytest.fixture +def mock_agent() -> MagicMock: + agent = MagicMock() + agent.run_task = AsyncMock() + agent.close = AsyncMock() + return agent + + 
+@pytest.mark.asyncio +async def test_complete_ticket_passes_triage( + mock_jira: MagicMock, + mock_agent: MagicMock, +) -> None: + """Verify that a complete ticket passes triage and moves to planning.""" + state = make_task_state(current_node="start") + mock_agent.run_task.return_value = "sufficient" + + with ( + patch("forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_triage.ForgeAgent", return_value=mock_agent), + ): + result = await triage_task(state) + + assert result["triage_passed"] is True + assert result["current_node"] == "generate_plan" + assert result["is_paused"] is False + assert result["triage_missing_fields"] == [] + + # Check Jira interactions + # 1. Ack comment posted first + # 2. Success comment posted + assert mock_jira.add_comment.call_count == 2 + mock_jira.add_comment.assert_any_call( + "TASK-123", + "Received task/epic for Task Takeover — checking ticket completeness before starting planning.", + ) + mock_jira.add_comment.assert_any_call( + "TASK-123", + "Ticket has enough information to proceed. 
Starting plan generation — results will be posted here.", + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "missing_fields, expected_missing_list", + [ + # Single missing section + (["Problem Statement"], ["Problem Statement"]), + (["Proposed Solution/Approach"], ["Proposed Solution/Approach"]), + (["Acceptance Criteria"], ["Acceptance Criteria"]), + # Combinations of missing sections + ( + ["Problem Statement", "Proposed Solution/Approach"], + ["Problem Statement", "Proposed Solution/Approach"], + ), + ( + ["Problem Statement", "Acceptance Criteria"], + ["Problem Statement", "Acceptance Criteria"], + ), + ( + ["Proposed Solution/Approach", "Acceptance Criteria"], + ["Proposed Solution/Approach", "Acceptance Criteria"], + ), + # All sections missing + ( + ["Problem Statement", "Proposed Solution/Approach", "Acceptance Criteria"], + ["Problem Statement", "Proposed Solution/Approach", "Acceptance Criteria"], + ), + # Malformed/Unexpected output fallback + ( + "not-a-list", + ["(could not determine — please provide additional context about the task)"], + ), + ], +) +async def test_incomplete_ticket_triage_permutations( + mock_jira: MagicMock, + mock_agent: MagicMock, + missing_fields: Any, + expected_missing_list: list[str], +) -> None: + """Verify that all permutations of missing sections trigger correct state, label, and comments.""" + state = make_task_state(current_node="start") + + if isinstance(missing_fields, list): + mock_agent.run_task.return_value = json.dumps(missing_fields) + else: + mock_agent.run_task.return_value = missing_fields + + with ( + patch("forge.workflow.nodes.task_takeover_triage.JiraClient", return_value=mock_jira), + patch("forge.workflow.nodes.task_takeover_triage.ForgeAgent", return_value=mock_agent), + ): + result = await triage_task(state) + + assert result["triage_passed"] is False + assert result["current_node"] == "triage_gate" + assert result["is_paused"] is True + assert result["triage_missing_fields"] == 
expected_missing_list + + # Verify label change to TASK_TRIAGE_PENDING + mock_jira.set_workflow_label.assert_called_once_with("TASK-123", ForgeLabel.TASK_TRIAGE_PENDING) + + # Verify detailed comment lists the missing fields + assert mock_jira.add_comment.call_count == 2 + detailed_comment = mock_jira.add_comment.call_args_list[1].args[1] + for field in expected_missing_list: + assert field in detailed_comment
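The triage tests above imply a small contract for the agent's reply: the literal `sufficient` passes triage, a JSON list names the missing sections, and any other output falls back to a placeholder entry. A minimal sketch of that contract follows. The helper `parse_triage_reply`, its signature, and its treatment of an empty list are assumptions for illustration only. They are not part of this diff, and the real `triage_task` node may parse the reply differently.

```python
import json


def parse_triage_reply(reply: str, fallback: list[str]) -> tuple[bool, list[str]]:
    """Return (triage_passed, missing_fields) for an agent reply.

    Hypothetical helper: name, signature, and fallback handling are assumptions.
    """
    text = reply.strip()
    if text.lower() == "sufficient":
        # Complete ticket: nothing is missing.
        return True, []
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Malformed output (for example "not-a-list"): use the placeholder entry.
        return False, list(fallback)
    # Only a non-empty list of strings counts as a valid "missing sections" reply.
    # Treating an empty list as malformed is an assumption of this sketch.
    if isinstance(parsed, list) and parsed and all(isinstance(item, str) for item in parsed):
        return False, parsed
    return False, list(fallback)
```

Under these assumptions, `parse_triage_reply("not-a-list", placeholder)` yields `(False, placeholder)`, which mirrors the malformed-output case in the parametrized test.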