diff --git a/.golangci.yml b/.golangci.yml index 3e28c53..56ba433 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -47,10 +47,12 @@ linters: # Exclude G301 (directory permissions) - workspace needs readable directories # Exclude G304 (file inclusion) - paths are validated via safePath() # Exclude G306 (file permissions) - workspace files need to be readable + # Exclude G706 (log injection) - we use slog structured logging which is inherently safe excludes: - G301 - G304 - G306 + - G706 formatters: enable: diff --git a/mcp/client.go b/mcp/client.go index cb333c0..d35b6f4 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "log/slog" + "sort" "strings" "sync" "time" @@ -31,54 +32,69 @@ type ClientManager struct { mu sync.Mutex clients map[string]*sdk_mcp.Client sessions map[string]*sdk_mcp.ClientSession + // headerHashes tracks the hash of headers used when creating each session, + // so we can detect when DynamicHeaders change and invalidate stale sessions. + headerHashes map[string]string } // NewClientManager creates a new ClientManager. func NewClientManager() *ClientManager { return &ClientManager{ - clients: make(map[string]*sdk_mcp.Client), - sessions: make(map[string]*sdk_mcp.ClientSession), + clients: make(map[string]*sdk_mcp.Client), + sessions: make(map[string]*sdk_mcp.ClientSession), + headerHashes: make(map[string]string), } } // GetSession returns or creates an MCP session for the given server. +// Sessions are cached by server name. If DynamicHeaders change (e.g. a different +// user or rotated credentials), the stale session is invalidated and recreated. func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServerConfig, logger *slog.Logger) (*sdk_mcp.ClientSession, error) { if logger == nil { logger = slog.With("server", server.Name) } + serverName := server.Name + currentHash := headersCacheKey(server.Headers, server.DynamicHeaders) + m.mu.Lock() defer m.mu.Unlock() - serverName := server.Name - logger.Debug("mcp resolving session", "transport", server.Transport, "url", server.URL, "command", server.Command, + "auth_key", maskKey(server.DynamicHeaders["Authorization"]), ) - // Check if session exists and is still valid + // Check if session exists and headers haven't changed if session, ok := m.sessions[serverName]; ok { - logger.Debug("mcp reusing session") - return session, nil + if m.headerHashes[serverName] == currentHash { + logger.Debug("mcp reusing session") + return session, nil + } + logger.Info("mcp headers changed, invalidating cached session", + "server", serverName, + ) + _ = session.Close() + delete(m.sessions, serverName) + delete(m.clients, serverName) + delete(m.headerHashes, serverName) } - // Create client if not exists - client, ok := m.clients[serverName] - if !ok { - client = sdk_mcp.NewClient(&sdk_mcp.Implementation{ - Name: "flashduty-runner", - Version: "1.0.0", - }, nil) - m.clients[serverName] = client - } + // Create client + client := sdk_mcp.NewClient(&sdk_mcp.Implementation{ + Name: "flashduty-runner", + Version: "1.0.0", + }, nil) + m.clients[serverName] = client // Create transport logger.Info("mcp creating transport", "transport", server.Transport, "url", server.URL, "command", server.Command, + "auth_key", maskKey(server.DynamicHeaders["Authorization"]), ) transport, err := createTransport(server) @@ -103,8 +119,11 @@ func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServ return nil, fmt.Errorf("failed to connect to MCP server '%s': %w", serverName, err) } - logger.Info("mcp connected") + logger.Info("mcp connected", + "auth_key", maskKey(server.DynamicHeaders["Authorization"]), + ) m.sessions[serverName] = session + m.headerHashes[serverName] = currentHash return session, nil } @@ -174,9 +193,34 @@ func (m *ClientManager) ListTools(ctx context.Context, server *protocol.MCPServe func (m *ClientManager) invalidateSession(serverName string) { m.mu.Lock() delete(m.sessions, serverName) + delete(m.headerHashes, serverName) m.mu.Unlock() } +func headersCacheKey(headers, dynamicHeaders map[string]string) string { + encode := func(m map[string]string) string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + parts := make([]string, len(keys)) + for i, k := range keys { + parts[i] = k + "=" + m[k] + } + return strings.Join(parts, "\x00") + } + return "s:" + encode(headers) + "\x01d:" + encode(dynamicHeaders) +} + +func maskKey(key string) string { + key = strings.TrimPrefix(key, "Bearer ") + if len(key) <= 6 { + return key + } + return key[:6] + "***" +} + // Close closes all active sessions and clients. func (m *ClientManager) Close() { m.mu.Lock() diff --git a/mcp/transport.go b/mcp/transport.go index 89426e1..3bc7e46 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "log/slog" "net/http" "os" @@ -12,7 +13,7 @@ import ( // NewStdioTransport creates a new stdio transport for MCP. func NewStdioTransport(command string, args []string, env map[string]string) sdk_mcp.Transport { - cmd := exec.Command(command, args...) + cmd := exec.CommandContext(context.Background(), command, args...) //nolint:gosec // G204: command comes from cloud-controlled MCP server config cmd.Env = buildEnv(env) return &sdk_mcp.CommandTransport{ Command: cmd, @@ -41,6 +42,7 @@ func NewSSETransport(endpoint string, headers map[string]string, dynamicHeaders "endpoint", endpoint, "headers_count", len(headers), "dynamic_headers_count", len(dynamicHeaders), + "auth_key", maskKey(dynamicHeaders["Authorization"]), ) return &sdk_mcp.StreamableClientTransport{ diff --git a/workspace/large_output.go b/workspace/large_output.go index 692ac8c..f2dc568 100644 --- a/workspace/large_output.go +++ b/workspace/large_output.go @@ -134,18 +134,18 @@ func (p *LargeOutputProcessor) truncateContent(content string, filePath string) // Build truncation message var sb strings.Builder sb.WriteString("\n") - sb.WriteString(fmt.Sprintf("Output too large (%d chars, %d lines).", len(content), totalLines)) + fmt.Fprintf(&sb, "Output too large (%d chars, %d lines).", len(content), totalLines) if filePath != "" { - sb.WriteString(fmt.Sprintf(" Full content saved to: %s\n\n", filePath)) + fmt.Fprintf(&sb, " Full content saved to: %s\n\n", filePath) } else { sb.WriteString(" Could not save full content.\n\n") } - sb.WriteString(fmt.Sprintf("Preview (first %d lines):\n```\n%s\n```\n\n", previewLines, preview)) + fmt.Fprintf(&sb, "Preview (first %d lines):\n```\n%s\n```\n\n", previewLines, preview) if filePath != "" { - sb.WriteString(fmt.Sprintf("To read more: read(\"%s\", offset=%d, limit=100)\n", filePath, previewLines)) + fmt.Fprintf(&sb, "To read more: read(\"%s\", offset=%d, limit=100)\n", filePath, previewLines) } sb.WriteString("") diff --git a/workspace/workspace.go b/workspace/workspace.go index a286136..e758b57 100644 --- a/workspace/workspace.go +++ b/workspace/workspace.go @@ -252,7 +252,7 @@ func (w *Workspace) Grep(ctx context.Context, args *protocol.GrepArgs) (*protoco // Build content string var sb strings.Builder for _, match := range res.Matches { - sb.WriteString(fmt.Sprintf("%s:%d:%s\n", match.Path, match.LineNumber, match.Content)) + fmt.Fprintf(&sb, "%s:%d:%s\n", match.Path, match.LineNumber, match.Content) } content := sb.String() @@ -274,13 +274,14 @@ func (w *Workspace) Grep(ctx context.Context, args *protocol.GrepArgs) (*protoco } func (w *Workspace) grepWithRipgrep(ctx context.Context, args *protocol.GrepArgs) (*protocol.GrepResult, error) { - cmdArgs := []string{"--column", "--line-number", "--no-heading", "--color", "never", "--smart-case"} + cmdArgs := make([]string, 0, 6+2*len(args.Include)+2) + cmdArgs = append(cmdArgs, "--column", "--line-number", "--no-heading", "--color", "never", "--smart-case") for _, inc := range args.Include { cmdArgs = append(cmdArgs, "--glob", inc) } cmdArgs = append(cmdArgs, args.Pattern, ".") - cmd := exec.CommandContext(ctx, "rg", cmdArgs...) + cmd := exec.CommandContext(ctx, "rg", cmdArgs...) //nolint:gosec // G204: args built from validated grep parameters cmd.Dir = w.root var stdout strings.Builder @@ -393,7 +394,7 @@ func (w *Workspace) executeBashCommand(ctx context.Context, command, workdir str ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - cmd := exec.CommandContext(ctx, "bash", "-c", command) + cmd := exec.CommandContext(ctx, "bash", "-c", command) //nolint:gosec // G204: command is user-initiated via workspace tool cmd.Dir = workdir // Use a limited writer to prevent OOM from very large outputs diff --git a/ws/client.go b/ws/client.go index 74a22c4..a117b2b 100644 --- a/ws/client.go +++ b/ws/client.go @@ -518,9 +518,10 @@ func getLinuxVersion() string { return "" } -// getCommandOutput executes a command and returns its trimmed output. func getCommandOutput(name string, args ...string) string { - out, err := exec.Command(name, args...).Output() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + out, err := exec.CommandContext(ctx, name, args...).Output() //nolint:gosec // G204: args are hardcoded system info commands if err == nil { return strings.TrimSpace(string(out)) } @@ -550,7 +551,7 @@ func getDefaultShell() string { func getTotalMemoryMB() int64 { switch runtime.GOOS { case "darwin": - out, err := exec.Command("sysctl", "-n", "hw.memsize").Output() + out, err := exec.CommandContext(context.Background(), "sysctl", "-n", "hw.memsize").Output() //nolint:gosec // G204: hardcoded command if err == nil { var bytes int64 if _, err := fmt.Sscanf(strings.TrimSpace(string(out)), "%d", &bytes); err == nil { diff --git a/ws/handler.go b/ws/handler.go index 963f7c0..95426e5 100644 --- a/ws/handler.go +++ b/ws/handler.go @@ -100,7 +100,7 @@ func (h *Handler) handleTaskRequest(ctx context.Context, msg *protocol.Message) "operation", req.Operation, ) - taskCtx, cancel := context.WithCancel(ctx) + taskCtx, cancel := context.WithCancel(ctx) //nolint:gosec // G118: cancel is stored in h.runningTask and called on task completion/cancellation h.mu.Lock() h.runningTask[req.TaskID] = cancel h.mu.Unlock()