Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ linters:
# Exclude G301 (directory permissions) - workspace needs readable directories
# Exclude G304 (file inclusion) - paths are validated via safePath()
# Exclude G306 (file permissions) - workspace files need to be readable
# Exclude G706 (log injection) - we use slog structured logging which is inherently safe
excludes:
- G301
- G304
- G306
- G706

formatters:
enable:
Expand Down
78 changes: 61 additions & 17 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"sync"
"time"
Expand All @@ -31,54 +32,69 @@ type ClientManager struct {
mu sync.Mutex
clients map[string]*sdk_mcp.Client
sessions map[string]*sdk_mcp.ClientSession
// headerHashes tracks the hash of headers used when creating each session,
// so we can detect when DynamicHeaders change and invalidate stale sessions.
headerHashes map[string]string
}

// NewClientManager creates a new ClientManager.
func NewClientManager() *ClientManager {
return &ClientManager{
clients: make(map[string]*sdk_mcp.Client),
sessions: make(map[string]*sdk_mcp.ClientSession),
clients: make(map[string]*sdk_mcp.Client),
sessions: make(map[string]*sdk_mcp.ClientSession),
headerHashes: make(map[string]string),
}
}

// GetSession returns or creates an MCP session for the given server.
// Sessions are cached by server name. If DynamicHeaders change (e.g. a different
// user or rotated credentials), the stale session is invalidated and recreated.
func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServerConfig, logger *slog.Logger) (*sdk_mcp.ClientSession, error) {
if logger == nil {
logger = slog.With("server", server.Name)
}

serverName := server.Name
currentHash := headersCacheKey(server.Headers, server.DynamicHeaders)

m.mu.Lock()
defer m.mu.Unlock()

serverName := server.Name

logger.Debug("mcp resolving session",
"transport", server.Transport,
"url", server.URL,
"command", server.Command,
"auth_key", maskKey(server.DynamicHeaders["Authorization"]),
)

// Check if session exists and is still valid
// Check if session exists and headers haven't changed
if session, ok := m.sessions[serverName]; ok {
logger.Debug("mcp reusing session")
return session, nil
if m.headerHashes[serverName] == currentHash {
logger.Debug("mcp reusing session")
return session, nil
}
logger.Info("mcp headers changed, invalidating cached session",
"server", serverName,
)
_ = session.Close()
delete(m.sessions, serverName)
delete(m.clients, serverName)
delete(m.headerHashes, serverName)
}

// Create client if not exists
client, ok := m.clients[serverName]
if !ok {
client = sdk_mcp.NewClient(&sdk_mcp.Implementation{
Name: "flashduty-runner",
Version: "1.0.0",
}, nil)
m.clients[serverName] = client
}
// Create client
client := sdk_mcp.NewClient(&sdk_mcp.Implementation{
Name: "flashduty-runner",
Version: "1.0.0",
}, nil)
m.clients[serverName] = client

// Create transport
logger.Info("mcp creating transport",
"transport", server.Transport,
"url", server.URL,
"command", server.Command,
"auth_key", maskKey(server.DynamicHeaders["Authorization"]),
)

transport, err := createTransport(server)
Expand All @@ -103,8 +119,11 @@ func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServ
return nil, fmt.Errorf("failed to connect to MCP server '%s': %w", serverName, err)
}

logger.Info("mcp connected")
logger.Info("mcp connected",
"auth_key", maskKey(server.DynamicHeaders["Authorization"]),
)
m.sessions[serverName] = session
m.headerHashes[serverName] = currentHash
return session, nil
}

Expand Down Expand Up @@ -174,9 +193,34 @@ func (m *ClientManager) ListTools(ctx context.Context, server *protocol.MCPServe
func (m *ClientManager) invalidateSession(serverName string) {
m.mu.Lock()
delete(m.sessions, serverName)
delete(m.headerHashes, serverName)
m.mu.Unlock()
}

func headersCacheKey(headers, dynamicHeaders map[string]string) string {
encode := func(m map[string]string) string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
parts := make([]string, len(keys))
for i, k := range keys {
parts[i] = k + "=" + m[k]
}
return strings.Join(parts, "\x00")
}
return "s:" + encode(headers) + "\x01d:" + encode(dynamicHeaders)
}

func maskKey(key string) string {
key = strings.TrimPrefix(key, "Bearer ")
if len(key) <= 6 {
return key
}
return key[:6] + "***"
}

// Close closes all active sessions and clients.
func (m *ClientManager) Close() {
m.mu.Lock()
Expand Down
4 changes: 3 additions & 1 deletion mcp/transport.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mcp

import (
"context"
"log/slog"
"net/http"
"os"
Expand All @@ -12,7 +13,7 @@ import (

// NewStdioTransport creates a new stdio transport for MCP.
func NewStdioTransport(command string, args []string, env map[string]string) sdk_mcp.Transport {
cmd := exec.Command(command, args...)
cmd := exec.CommandContext(context.Background(), command, args...) //nolint:gosec // G204: command comes from cloud-controlled MCP server config
cmd.Env = buildEnv(env)
return &sdk_mcp.CommandTransport{
Command: cmd,
Expand Down Expand Up @@ -41,6 +42,7 @@ func NewSSETransport(endpoint string, headers map[string]string, dynamicHeaders
"endpoint", endpoint,
"headers_count", len(headers),
"dynamic_headers_count", len(dynamicHeaders),
"auth_key", maskKey(dynamicHeaders["Authorization"]),
)

return &sdk_mcp.StreamableClientTransport{
Expand Down
8 changes: 4 additions & 4 deletions workspace/large_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,18 @@ func (p *LargeOutputProcessor) truncateContent(content string, filePath string)
// Build truncation message
var sb strings.Builder
sb.WriteString("<output_truncated>\n")
sb.WriteString(fmt.Sprintf("Output too large (%d chars, %d lines).", len(content), totalLines))
fmt.Fprintf(&sb, "Output too large (%d chars, %d lines).", len(content), totalLines)

if filePath != "" {
sb.WriteString(fmt.Sprintf(" Full content saved to: %s\n\n", filePath))
fmt.Fprintf(&sb, " Full content saved to: %s\n\n", filePath)
} else {
sb.WriteString(" Could not save full content.\n\n")
}

sb.WriteString(fmt.Sprintf("Preview (first %d lines):\n```\n%s\n```\n\n", previewLines, preview))
fmt.Fprintf(&sb, "Preview (first %d lines):\n```\n%s\n```\n\n", previewLines, preview)

if filePath != "" {
sb.WriteString(fmt.Sprintf("To read more: read(\"%s\", offset=%d, limit=100)\n", filePath, previewLines))
fmt.Fprintf(&sb, "To read more: read(\"%s\", offset=%d, limit=100)\n", filePath, previewLines)
}

sb.WriteString("</output_truncated>")
Expand Down
9 changes: 5 additions & 4 deletions workspace/workspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (w *Workspace) Grep(ctx context.Context, args *protocol.GrepArgs) (*protoco
// Build content string
var sb strings.Builder
for _, match := range res.Matches {
sb.WriteString(fmt.Sprintf("%s:%d:%s\n", match.Path, match.LineNumber, match.Content))
fmt.Fprintf(&sb, "%s:%d:%s\n", match.Path, match.LineNumber, match.Content)
}
content := sb.String()

Expand All @@ -274,13 +274,14 @@ func (w *Workspace) Grep(ctx context.Context, args *protocol.GrepArgs) (*protoco
}

func (w *Workspace) grepWithRipgrep(ctx context.Context, args *protocol.GrepArgs) (*protocol.GrepResult, error) {
cmdArgs := []string{"--column", "--line-number", "--no-heading", "--color", "never", "--smart-case"}
cmdArgs := make([]string, 0, 6+2*len(args.Include)+2)
cmdArgs = append(cmdArgs, "--column", "--line-number", "--no-heading", "--color", "never", "--smart-case")
for _, inc := range args.Include {
cmdArgs = append(cmdArgs, "--glob", inc)
}
cmdArgs = append(cmdArgs, args.Pattern, ".")

cmd := exec.CommandContext(ctx, "rg", cmdArgs...)
cmd := exec.CommandContext(ctx, "rg", cmdArgs...) //nolint:gosec // G204: args built from validated grep parameters
cmd.Dir = w.root

var stdout strings.Builder
Expand Down Expand Up @@ -393,7 +394,7 @@ func (w *Workspace) executeBashCommand(ctx context.Context, command, workdir str
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

cmd := exec.CommandContext(ctx, "bash", "-c", command)
cmd := exec.CommandContext(ctx, "bash", "-c", command) //nolint:gosec // G204: command is user-initiated via workspace tool
cmd.Dir = workdir

// Use a limited writer to prevent OOM from very large outputs
Expand Down
7 changes: 4 additions & 3 deletions ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,10 @@ func getLinuxVersion() string {
return ""
}

// getCommandOutput executes a command and returns its trimmed output.
func getCommandOutput(name string, args ...string) string {
out, err := exec.Command(name, args...).Output()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
out, err := exec.CommandContext(ctx, name, args...).Output() //nolint:gosec // G204: args are hardcoded system info commands
if err == nil {
return strings.TrimSpace(string(out))
}
Expand Down Expand Up @@ -550,7 +551,7 @@ func getDefaultShell() string {
func getTotalMemoryMB() int64 {
switch runtime.GOOS {
case "darwin":
out, err := exec.Command("sysctl", "-n", "hw.memsize").Output()
out, err := exec.CommandContext(context.Background(), "sysctl", "-n", "hw.memsize").Output() //nolint:gosec // G204: hardcoded command
if err == nil {
var bytes int64
if _, err := fmt.Sscanf(strings.TrimSpace(string(out)), "%d", &bytes); err == nil {
Expand Down
2 changes: 1 addition & 1 deletion ws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h *Handler) handleTaskRequest(ctx context.Context, msg *protocol.Message)
"operation", req.Operation,
)

taskCtx, cancel := context.WithCancel(ctx)
taskCtx, cancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel is stored in h.runningTask and called on task completion/cancellation
h.mu.Lock()
h.runningTask[req.TaskID] = cancel
h.mu.Unlock()
Expand Down
Loading