diff --git a/cmd/mcp/flags.go b/cmd/mcp/flags.go index ffe0ecd..7b7eaf9 100644 --- a/cmd/mcp/flags.go +++ b/cmd/mcp/flags.go @@ -73,7 +73,7 @@ var ServeFlags = []FlagDef{ ViperKey: "mcp.allow_api_key_query_param", Default: "false", IsBool: true, - Usage: "Accept the Seerr API key via the api_key query parameter in addition to the X-Api-Key header (HTTP transport only)", + Usage: "Accept the MCP auth token via the api_key query parameter in addition to headers (HTTP transport only)", }, { Name: "log-file", diff --git a/cmd/mcp/mcp.go b/cmd/mcp/mcp.go index 51f26df..5ef4c9a 100644 --- a/cmd/mcp/mcp.go +++ b/cmd/mcp/mcp.go @@ -2,7 +2,9 @@ package mcp import ( "context" + "encoding/json" "fmt" + "net/http" "seerr-cli/cmd/apiutil" api "seerr-cli/pkg/api" @@ -63,6 +65,22 @@ func apiToolError(label string, err error) (*mcplib.CallToolResult, error) { return mcplib.NewToolResultError(msg), nil } +// marshalResult marshals res to JSON. When Execute() returns a +// *GenericOpenAPIError on a 2xx response the structured decode failed because +// the live API returned fields not present in the generated OpenAPI model. In +// that case the raw body carried by the error is returned so the caller still +// receives the full response data instead of a false failure. +func marshalResult(res interface{}, httpResp *http.Response, err error) ([]byte, error) { + if err != nil { + if e, ok := err.(*api.GenericOpenAPIError); ok && + httpResp != nil && httpResp.StatusCode < 300 && len(e.Body()) > 0 { + return e.Body(), nil + } + return nil, err + } + return json.Marshal(res) +} + func init() { // Subcommands are added in their respective files' init() functions. } diff --git a/cmd/mcp/serve.go b/cmd/mcp/serve.go index 685f08a..aaf9451 100644 --- a/cmd/mcp/serve.go +++ b/cmd/mcp/serve.go @@ -159,11 +159,9 @@ func runServe(_ *cobra.Command, args []string) error { httpHandler := server.NewStreamableHTTPServer(s) handler := http.Handler(httpHandler) - // Per-request Seerr API key injection (header or optional query param). - handler = SeerrAPIKeyMiddleware(allowAPIKeyQueryParam, handler) handler = httpLoggingMiddleware(handler) if authToken != "" { - handler = bearerAuthMiddleware(authToken, handler) + handler = MCPAuthMiddleware(authToken, allowAPIKeyQueryParam, handler) } // The health endpoint must be reachable without auth, so register it in // a top-level mux that sits above the bearer-auth middleware. @@ -243,50 +241,31 @@ func corsMiddleware(next http.Handler) http.Handler { }) } -func bearerAuthMiddleware(token string, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - const prefix = "Bearer " - if !strings.HasPrefix(authHeader, prefix) { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - provided := strings.TrimPrefix(authHeader, prefix) - if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - next.ServeHTTP(w, r) - }) -} - -// SeerrAPIKeyMiddleware extracts the Seerr API key from the incoming request -// and injects it into the request context for use by MCP tool handlers. +// MCPAuthMiddleware authenticates incoming MCP HTTP requests against token. // -// The key is read from the X-Api-Key request header first. When -// allowQueryParam is true the middleware also accepts the key via the -// api_key query parameter; the header takes precedence when both are present. +// Accepted credential locations (in precedence order): +// 1. Authorization: Bearer +// 2. X-Api-Key: +// 3. ?api_key= query parameter — only when allowQueryParam is true. // -// If neither location provides a key the middleware responds with 401. -func SeerrAPIKeyMiddleware(allowQueryParam bool, next http.Handler) http.Handler { +// Requests that do not supply a matching credential are rejected with 401. +// The Seerr API key used for outbound Seerr calls is always read from the +// application config (seerr.api_key) and is never sourced from this middleware. +func MCPAuthMiddleware(token string, allowQueryParam bool, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var apiKey string - if v := r.Header.Get("X-Api-Key"); v != "" { - apiKey = v + var provided string + if v := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer "); v != r.Header.Get("Authorization") { + provided = v + } else if v := r.Header.Get("X-Api-Key"); v != "" { + provided = v } else if allowQueryParam { - if v := r.URL.Query().Get("api_key"); v != "" { - apiKey = v - } + provided = r.URL.Query().Get("api_key") } - if apiKey != "" { - ctx := context.WithValue(r.Context(), apiKeyCtxKey, apiKey) - r = r.Clone(ctx) - } else { + if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - next.ServeHTTP(w, r) }) } diff --git a/cmd/mcp/tools_request.go b/cmd/mcp/tools_request.go index 4754a49..e2a24e3 100644 --- a/cmd/mcp/tools_request.go +++ b/cmd/mcp/tools_request.go @@ -3,6 +3,8 @@ package mcp import ( "context" "encoding/json" + "strconv" + "strings" api "seerr-cli/pkg/api" @@ -44,6 +46,7 @@ func registerRequestTools(s *server.MCPServer) { mcp.WithIdempotentHintAnnotation(false), mcp.WithString("mediaType", mcp.Required(), mcp.Description("Media type: movie or tv")), mcp.WithNumber("mediaId", mcp.Required(), mcp.Description("TMDB media ID")), + mcp.WithString("seasons", mcp.Description(`Seasons to request (required for tv): "all" or comma-separated numbers e.g. "1,2,3"`)), mcp.WithBoolean("is4k", mcp.Description("Request 4K version")), ), RequestCreateHandler(), @@ -117,14 +120,11 @@ func RequestListHandler() server.ToolHandlerFunc { if sort := req.GetString("sort", ""); sort != "" { r = r.Sort(sort) } - res, _, err := r.Execute() + res, httpResp, err := r.Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("RequestGet failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -136,14 +136,11 @@ func RequestGetHandler() server.ToolHandlerFunc { return nil, err } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.RequestAPI.RequestRequestIdGet(callCtx, requestId).Execute() + res, httpResp, err := client.RequestAPI.RequestRequestIdGet(callCtx, requestId).Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("RequestRequestIdGet failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -159,18 +156,32 @@ func RequestCreateHandler() server.ToolHandlerFunc { return nil, err } body := api.NewRequestPostRequest(mediaType, float32(mediaId)) + if seasons := req.GetString("seasons", ""); seasons != "" { + if seasons == "all" { + s := "all" + body.SetSeasons(api.StringAsRequestPostRequestSeasons(&s)) + } else { + parts := strings.Split(seasons, ",") + nums := make([]float32, 0, len(parts)) + for _, p := range parts { + n, err := strconv.ParseInt(strings.TrimSpace(p), 10, 64) + if err != nil { + return mcp.NewToolResultError("invalid season number: " + p), nil + } + nums = append(nums, float32(n)) + } + body.SetSeasons(api.ArrayOfFloat32AsRequestPostRequestSeasons(&nums)) + } + } if is4k := req.GetBool("is4k", false); is4k { body.Is4k = &is4k } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.RequestAPI.RequestPost(callCtx).RequestPostRequest(*body).Execute() + res, httpResp, err := client.RequestAPI.RequestPost(callCtx).RequestPostRequest(*body).Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("RequestPost failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -182,14 +193,11 @@ func RequestApproveHandler() server.ToolHandlerFunc { return nil, err } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "approve").Execute() + res, httpResp, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "approve").Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("RequestRequestIdStatusPost(approve) failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -201,14 +209,11 @@ func RequestDeclineHandler() server.ToolHandlerFunc { return nil, err } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "decline").Execute() + res, httpResp, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "decline").Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("RequestRequestIdStatusPost(decline) failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -235,14 +240,11 @@ func RequestRetryHandler() server.ToolHandlerFunc { return nil, err } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.RequestAPI.RequestRequestIdRetryPost(callCtx, requestId).Execute() + res, httpResp, err := client.RequestAPI.RequestRequestIdRetryPost(callCtx, requestId).Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("RequestRequestIdRetryPost failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } diff --git a/cmd/mcp/tools_users.go b/cmd/mcp/tools_users.go index bcb5fef..e4b0cf1 100644 --- a/cmd/mcp/tools_users.go +++ b/cmd/mcp/tools_users.go @@ -2,7 +2,6 @@ package mcp import ( "context" - "encoding/json" api "seerr-cli/pkg/api" @@ -58,14 +57,11 @@ func UsersListHandler() server.ToolHandlerFunc { if skip := req.GetInt("skip", 0); skip > 0 { r = r.Skip(float32(skip)) } - res, _, err := r.Execute() + res, httpResp, err := r.Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("UserGet failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -77,14 +73,11 @@ func UsersGetHandler() server.ToolHandlerFunc { return nil, err } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.UsersAPI.UserUserIdGet(callCtx, float32(userId)).Execute() + res, httpResp, err := client.UsersAPI.UserUserIdGet(callCtx, float32(userId)).Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("UserUserIdGet failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -113,14 +106,11 @@ func UsersUpdateHandler() server.ToolHandlerFunc { return mcp.NewToolResultError("at least one field must be provided"), nil } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.UsersAPI.UserUserIdPut(callCtx, float32(userId)).UserUpdatePayload(body).Execute() + res, httpResp, err := client.UsersAPI.UserUserIdPut(callCtx, float32(userId)).UserUpdatePayload(body).Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("UserUserIdPut failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } @@ -132,14 +122,11 @@ func UsersQuotaHandler() server.ToolHandlerFunc { return nil, err } client := newAPIClientWithKey(apiKeyFromContext(callCtx)) - res, _, err := client.UsersAPI.UserUserIdQuotaGet(callCtx, float32(userId)).Execute() + res, httpResp, err := client.UsersAPI.UserUserIdQuotaGet(callCtx, float32(userId)).Execute() + b, err := marshalResult(res, httpResp, err) if err != nil { return apiToolError("UserUserIdQuotaGet failed", err) } - b, err := json.Marshal(res) - if err != nil { - return nil, err - } return mcp.NewToolResultText(string(b)), nil } } diff --git a/tests/mcp_api_key_middleware_test.go b/tests/mcp_api_key_middleware_test.go index d5bf279..8124fb9 100644 --- a/tests/mcp_api_key_middleware_test.go +++ b/tests/mcp_api_key_middleware_test.go @@ -10,100 +10,119 @@ import ( "github.com/stretchr/testify/assert" ) -// captureContextKey is a helper that records the API key injected into the -// request context by the middleware. -func captureContextKey(t *testing.T) (handler http.Handler, captured *string) { - t.Helper() - s := "" - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s, _ = r.Context().Value(cmdmcp.APIKeyContextKey).(string) +const testAuthToken = "super-secret" + +// newAuthTestHandler returns an inner handler that records whether it was called. +func newAuthTestHandler(called *bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *called = true w.WriteHeader(http.StatusOK) }) - return h, &s } -func TestSeerrAPIKeyMiddleware_headerOnly(t *testing.T) { - inner, captured := captureContextKey(t) - handler := cmdmcp.SeerrAPIKeyMiddleware(false, inner) +func TestMCPAuthMiddleware_bearerToken(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, false, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) - req.Header.Set("X-Api-Key", "header-key") + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer "+testAuthToken) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "header-key", *captured) + assert.True(t, called) } -func TestSeerrAPIKeyMiddleware_queryParamOnly(t *testing.T) { - inner, captured := captureContextKey(t) - handler := cmdmcp.SeerrAPIKeyMiddleware(true, inner) +func TestMCPAuthMiddleware_xApiKeyHeader(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, false, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp?api_key=qparam-key", nil) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("X-Api-Key", testAuthToken) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "qparam-key", *captured) + assert.True(t, called) } -func TestSeerrAPIKeyMiddleware_headerPrecedenceOverQueryParam(t *testing.T) { - inner, captured := captureContextKey(t) - handler := cmdmcp.SeerrAPIKeyMiddleware(true, inner) +func TestMCPAuthMiddleware_queryParam(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, true, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp?api_key=qparam-key", nil) - req.Header.Set("X-Api-Key", "header-key") + req := httptest.NewRequest(http.MethodPost, "/mcp?api_key="+testAuthToken, nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "header-key", *captured, "header must take precedence over query param") + assert.True(t, called) } -func TestSeerrAPIKeyMiddleware_queryParamDisabled_ignoresQueryParam(t *testing.T) { - var innerCalled bool - inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - innerCalled = true - w.WriteHeader(http.StatusOK) - }) - // allowQueryParam=false; query param present but should not satisfy auth. - handler := cmdmcp.SeerrAPIKeyMiddleware(false, inner) +func TestMCPAuthMiddleware_queryParamDisabled_returns401(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, false, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp?api_key=qparam-key", nil) + req := httptest.NewRequest(http.MethodPost, "/mcp?api_key="+testAuthToken, nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) - // No header, query param disabled — must return 401. assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.False(t, innerCalled) + assert.False(t, called) } -func TestSeerrAPIKeyMiddleware_neitherPresent_returns401(t *testing.T) { - var innerCalled bool - inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - innerCalled = true - w.WriteHeader(http.StatusOK) - }) - handler := cmdmcp.SeerrAPIKeyMiddleware(true, inner) +func TestMCPAuthMiddleware_wrongToken_returns401(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, true, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("X-Api-Key", "wrong-token") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.False(t, innerCalled) + assert.False(t, called) } -func TestSeerrAPIKeyMiddleware_queryParam_sensitiveValueNotLogged(t *testing.T) { - // This test ensures SafeLogQuery redacts the api_key value from query strings. +func TestMCPAuthMiddleware_noCredentials_returns401(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, true, newAuthTestHandler(&called)) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.False(t, called) +} + +// TestMCPAuthMiddleware_noAuthToken verifies that when no --auth-token is +// configured the middleware is not applied and requests pass through without +// any credential check. This is the no-auth / --no-auth scenario. +func TestMCPAuthMiddleware_noAuthToken_requestPassesThrough(t *testing.T) { + // When authToken is empty the serve command does not wrap the handler with + // MCPAuthMiddleware at all. The Seerr API key is always read from the app + // config (seerr.api_key) and never sourced from the incoming request. + var called bool + inner := newAuthTestHandler(&called) + + // Simulate the no-auth path: handler is NOT wrapped with MCPAuthMiddleware. + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rec := httptest.NewRecorder() + inner.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, called) +} + +func TestMCPAuthMiddlewareIsTheOnlyPerRequestAuthMechanism(t *testing.T) { + // Compile-time check that MCPAuthMiddleware exists and has the expected signature. + _ = cmdmcp.MCPAuthMiddleware +} + +func TestMCPLogQueryRedaction(t *testing.T) { + // Ensure SafeLogQuery redacts the api_key value from query strings. redacted := cmdmcp.SafeLogQuery("api_key=secret123&page=1") assert.NotContains(t, redacted, "secret123") assert.Contains(t, redacted, "api_key={redacted}") assert.Contains(t, redacted, "page=1") } - -func TestSeerrAPIKeyMiddlewareIsTheOnlyPerRequestKeyMechanism(t *testing.T) { - // SeerrAPIKeyMiddleware must compile and be the sole per-request API key - // injection mechanism — path-based routing has been removed. - _ = cmdmcp.SeerrAPIKeyMiddleware -} diff --git a/tests/mcp_marshal_result_test.go b/tests/mcp_marshal_result_test.go new file mode 100644 index 0000000..ad7ec3f --- /dev/null +++ b/tests/mcp_marshal_result_test.go @@ -0,0 +1,99 @@ +package tests + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + cmdmcp "seerr-cli/cmd/mcp" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestMCPRequestCreateFallbackOnDecodeError verifies that request_create returns +// the raw response body as a success result when the HTTP call succeeds (201) +// but the generated client fails to decode the response. This happens in the +// wild because Overseerr's POST /api/v1/request response embeds a UserSettings +// object with an "id" field that the generated model rejects via +// DisallowUnknownFields. +func TestMCPRequestCreateFallbackOnDecodeError(t *testing.T) { + // The nested settings.id field is not in the generated UserSettings struct. + // The strict decoder rejects it even though the HTTP response is 201. + overseerrBody := `{ + "id": 50, + "status": 2, + "type": "tv", + "media": {"mediaType": "tv"}, + "requestedBy": { + "id": 1, + "settings": { + "id": 1, + "locale": "", + "discoverRegion": "" + } + } + }` + + ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/v1/request" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(overseerrBody)) + return + } + w.WriteHeader(http.StatusNotFound) + }) + defer cleanup() + _ = ts + + handler := cmdmcp.RequestCreateHandler() + result := callTool(t, handler, map[string]any{ + "mediaType": "tv", + "mediaId": float64(83100), + "seasons": "all", + }) + + // Must not be an error — the request was created successfully. + require.False(t, result.IsError, + "2xx response with unknown fields must not produce a tool error") + + text := resultText(t, result) + + // The raw body must be returned so the caller has the full response. + var parsed map[string]interface{} + require.NoError(t, json.Unmarshal([]byte(text), &parsed)) + assert.Equal(t, float64(50), parsed["id"]) + assert.Equal(t, float64(2), parsed["status"]) +} + +// TestMCPRequestCreateRealErrorStillFails verifies that a genuine Overseerr +// error (non-2xx) is still surfaced as a tool error, not silently swallowed by +// the fallback path. +func TestMCPRequestCreateRealErrorStillFails(t *testing.T) { + ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/v1/request" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"message":"Invalid request"}`)) + return + } + w.WriteHeader(http.StatusNotFound) + }) + defer cleanup() + _ = ts + + handler := cmdmcp.RequestCreateHandler() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "mediaType": "tv", + "mediaId": float64(83100), + "seasons": "all", + } + result, err := handler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError, "non-2xx response must produce a tool error") +} diff --git a/tests/mcp_serve_test.go b/tests/mcp_serve_test.go index 6f1d23a..1b3ec8e 100644 --- a/tests/mcp_serve_test.go +++ b/tests/mcp_serve_test.go @@ -3,6 +3,7 @@ package tests import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" "strings" @@ -412,12 +413,123 @@ func TestMCPBlocklistListHandler(t *testing.T) { assert.Contains(t, text, `"results"`) } +func TestMCPRequestCreateMovieHandler(t *testing.T) { + var capturedBody []byte + ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/v1/request" { + capturedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"id":1,"status":1,"media":{"mediaType":"movie"}}`)) + return + } + w.WriteHeader(http.StatusNotFound) + }) + defer cleanup() + _ = ts + + handler := cmdmcp.RequestCreateHandler() + result := callTool(t, handler, map[string]any{ + "mediaType": "movie", + "mediaId": float64(550), + }) + text := resultText(t, result) + + assert.Contains(t, text, `"id"`) + + var body map[string]interface{} + require.NoError(t, json.Unmarshal(capturedBody, &body)) + assert.Equal(t, "movie", body["mediaType"]) + assert.NotContains(t, body, "seasons", "movie requests must not send a seasons field") +} + +func TestMCPRequestCreateTVSeasonsAll(t *testing.T) { + var capturedBody []byte + ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/v1/request" { + capturedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"id":2,"status":1,"media":{"mediaType":"tv"}}`)) + return + } + w.WriteHeader(http.StatusNotFound) + }) + defer cleanup() + _ = ts + + handler := cmdmcp.RequestCreateHandler() + result := callTool(t, handler, map[string]any{ + "mediaType": "tv", + "mediaId": float64(1399), + "seasons": "all", + }) + text := resultText(t, result) + + assert.Contains(t, text, `"id"`) + + var body map[string]interface{} + require.NoError(t, json.Unmarshal(capturedBody, &body)) + assert.Equal(t, "all", body["seasons"], "seasons=all must be forwarded as the string \"all\"") +} + +func TestMCPRequestCreateTVSpecificSeasons(t *testing.T) { + var capturedBody []byte + ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/v1/request" { + capturedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"id":3,"status":1,"media":{"mediaType":"tv"}}`)) + return + } + w.WriteHeader(http.StatusNotFound) + }) + defer cleanup() + _ = ts + + handler := cmdmcp.RequestCreateHandler() + result := callTool(t, handler, map[string]any{ + "mediaType": "tv", + "mediaId": float64(1399), + "seasons": "1,2", + }) + text := resultText(t, result) + + assert.Contains(t, text, `"id"`) + + var body map[string]interface{} + require.NoError(t, json.Unmarshal(capturedBody, &body)) + seasons, ok := body["seasons"].([]interface{}) + require.True(t, ok, "seasons must be a JSON array") + assert.Equal(t, []interface{}{float64(1), float64(2)}, seasons) +} + +func TestMCPRequestCreateInvalidSeason(t *testing.T) { + ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + defer cleanup() + _ = ts + + handler := cmdmcp.RequestCreateHandler() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]any{ + "mediaType": "tv", + "mediaId": float64(1399), + "seasons": "1,abc", + } + result, err := handler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError, "invalid season number must produce a tool error") +} + // --- API key context propagation test --- func TestAPIKeyContextPropagation(t *testing.T) { // Verify that an API key injected into the context is forwarded to the - // Seerr API as the X-Api-Key header, matching how SeerrAPIKeyMiddleware - // injects keys for downstream tool handlers. + // Seerr API as the X-Api-Key header by tool handlers. var receivedAPIKey string ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) {