Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/mcp/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ var ServeFlags = []FlagDef{
ViperKey: "mcp.allow_api_key_query_param",
Default: "false",
IsBool: true,
Usage: "Accept the Seerr API key via the api_key query parameter in addition to the X-Api-Key header (HTTP transport only)",
Usage: "Accept the MCP auth token via the api_key query parameter in addition to headers (HTTP transport only)",
},
{
Name: "log-file",
Expand Down
18 changes: 18 additions & 0 deletions cmd/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package mcp

import (
"context"
"encoding/json"
"fmt"
"net/http"

"seerr-cli/cmd/apiutil"
api "seerr-cli/pkg/api"
Expand Down Expand Up @@ -63,6 +65,22 @@ func apiToolError(label string, err error) (*mcplib.CallToolResult, error) {
return mcplib.NewToolResultError(msg), nil
}

// marshalResult marshals res to JSON. When Execute() returns a
// *GenericOpenAPIError on a 2xx response the structured decode failed because
// the live API returned fields not present in the generated OpenAPI model. In
// that case the raw body carried by the error is returned so the caller still
// receives the full response data instead of a false failure.
func marshalResult(res interface{}, httpResp *http.Response, err error) ([]byte, error) {
if err != nil {
if e, ok := err.(*api.GenericOpenAPIError); ok &&
httpResp != nil && httpResp.StatusCode < 300 && len(e.Body()) > 0 {
return e.Body(), nil
}
return nil, err
}
return json.Marshal(res)
}

func init() {
// Subcommands are added in their respective files' init() functions.
}
55 changes: 17 additions & 38 deletions cmd/mcp/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,9 @@ func runServe(_ *cobra.Command, args []string) error {

httpHandler := server.NewStreamableHTTPServer(s)
handler := http.Handler(httpHandler)
// Per-request Seerr API key injection (header or optional query param).
handler = SeerrAPIKeyMiddleware(allowAPIKeyQueryParam, handler)
handler = httpLoggingMiddleware(handler)
if authToken != "" {
handler = bearerAuthMiddleware(authToken, handler)
handler = MCPAuthMiddleware(authToken, allowAPIKeyQueryParam, handler)
}
// The health endpoint must be reachable without auth, so register it in
// a top-level mux that sits above the bearer-auth middleware.
Expand Down Expand Up @@ -243,50 +241,31 @@ func corsMiddleware(next http.Handler) http.Handler {
})
}

func bearerAuthMiddleware(token string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
const prefix = "Bearer "
if !strings.HasPrefix(authHeader, prefix) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
provided := strings.TrimPrefix(authHeader, prefix)
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}

// SeerrAPIKeyMiddleware extracts the Seerr API key from the incoming request
// and injects it into the request context for use by MCP tool handlers.
// MCPAuthMiddleware authenticates incoming MCP HTTP requests against token.
//
// The key is read from the X-Api-Key request header first. When
// allowQueryParam is true the middleware also accepts the key via the
// api_key query parameter; the header takes precedence when both are present.
// Accepted credential locations (in precedence order):
// 1. Authorization: Bearer <token>
// 2. X-Api-Key: <token>
// 3. ?api_key=<token> query parameter — only when allowQueryParam is true.
//
// If neither location provides a key the middleware responds with 401.
func SeerrAPIKeyMiddleware(allowQueryParam bool, next http.Handler) http.Handler {
// Requests that do not supply a matching credential are rejected with 401.
// The Seerr API key used for outbound Seerr calls is always read from the
// application config (seerr.api_key) and is never sourced from this middleware.
func MCPAuthMiddleware(token string, allowQueryParam bool, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var apiKey string
if v := r.Header.Get("X-Api-Key"); v != "" {
apiKey = v
var provided string
if v := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer "); v != r.Header.Get("Authorization") {
provided = v
} else if v := r.Header.Get("X-Api-Key"); v != "" {
provided = v
} else if allowQueryParam {
if v := r.URL.Query().Get("api_key"); v != "" {
apiKey = v
}
provided = r.URL.Query().Get("api_key")
}

if apiKey != "" {
ctx := context.WithValue(r.Context(), apiKeyCtxKey, apiKey)
r = r.Clone(ctx)
} else {
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

next.ServeHTTP(w, r)
})
}
62 changes: 32 additions & 30 deletions cmd/mcp/tools_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package mcp
import (
"context"
"encoding/json"
"strconv"
"strings"

api "seerr-cli/pkg/api"

Expand Down Expand Up @@ -44,6 +46,7 @@ func registerRequestTools(s *server.MCPServer) {
mcp.WithIdempotentHintAnnotation(false),
mcp.WithString("mediaType", mcp.Required(), mcp.Description("Media type: movie or tv")),
mcp.WithNumber("mediaId", mcp.Required(), mcp.Description("TMDB media ID")),
mcp.WithString("seasons", mcp.Description(`Seasons to request (required for tv): "all" or comma-separated numbers e.g. "1,2,3"`)),
mcp.WithBoolean("is4k", mcp.Description("Request 4K version")),
),
RequestCreateHandler(),
Expand Down Expand Up @@ -117,14 +120,11 @@ func RequestListHandler() server.ToolHandlerFunc {
if sort := req.GetString("sort", ""); sort != "" {
r = r.Sort(sort)
}
res, _, err := r.Execute()
res, httpResp, err := r.Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("RequestGet failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -136,14 +136,11 @@ func RequestGetHandler() server.ToolHandlerFunc {
return nil, err
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.RequestAPI.RequestRequestIdGet(callCtx, requestId).Execute()
res, httpResp, err := client.RequestAPI.RequestRequestIdGet(callCtx, requestId).Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("RequestRequestIdGet failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -159,18 +156,32 @@ func RequestCreateHandler() server.ToolHandlerFunc {
return nil, err
}
body := api.NewRequestPostRequest(mediaType, float32(mediaId))
if seasons := req.GetString("seasons", ""); seasons != "" {
if seasons == "all" {
s := "all"
body.SetSeasons(api.StringAsRequestPostRequestSeasons(&s))
} else {
parts := strings.Split(seasons, ",")
nums := make([]float32, 0, len(parts))
for _, p := range parts {
n, err := strconv.ParseInt(strings.TrimSpace(p), 10, 64)
if err != nil {
return mcp.NewToolResultError("invalid season number: " + p), nil
}
nums = append(nums, float32(n))
}
body.SetSeasons(api.ArrayOfFloat32AsRequestPostRequestSeasons(&nums))
}
}
if is4k := req.GetBool("is4k", false); is4k {
body.Is4k = &is4k
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.RequestAPI.RequestPost(callCtx).RequestPostRequest(*body).Execute()
res, httpResp, err := client.RequestAPI.RequestPost(callCtx).RequestPostRequest(*body).Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("RequestPost failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -182,14 +193,11 @@ func RequestApproveHandler() server.ToolHandlerFunc {
return nil, err
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "approve").Execute()
res, httpResp, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "approve").Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("RequestRequestIdStatusPost(approve) failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -201,14 +209,11 @@ func RequestDeclineHandler() server.ToolHandlerFunc {
return nil, err
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "decline").Execute()
res, httpResp, err := client.RequestAPI.RequestRequestIdStatusPost(callCtx, requestId, "decline").Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("RequestRequestIdStatusPost(decline) failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -235,14 +240,11 @@ func RequestRetryHandler() server.ToolHandlerFunc {
return nil, err
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.RequestAPI.RequestRequestIdRetryPost(callCtx, requestId).Execute()
res, httpResp, err := client.RequestAPI.RequestRequestIdRetryPost(callCtx, requestId).Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("RequestRequestIdRetryPost failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand Down
29 changes: 8 additions & 21 deletions cmd/mcp/tools_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package mcp

import (
"context"
"encoding/json"

api "seerr-cli/pkg/api"

Expand Down Expand Up @@ -58,14 +57,11 @@ func UsersListHandler() server.ToolHandlerFunc {
if skip := req.GetInt("skip", 0); skip > 0 {
r = r.Skip(float32(skip))
}
res, _, err := r.Execute()
res, httpResp, err := r.Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("UserGet failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -77,14 +73,11 @@ func UsersGetHandler() server.ToolHandlerFunc {
return nil, err
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.UsersAPI.UserUserIdGet(callCtx, float32(userId)).Execute()
res, httpResp, err := client.UsersAPI.UserUserIdGet(callCtx, float32(userId)).Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("UserUserIdGet failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand Down Expand Up @@ -113,14 +106,11 @@ func UsersUpdateHandler() server.ToolHandlerFunc {
return mcp.NewToolResultError("at least one field must be provided"), nil
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.UsersAPI.UserUserIdPut(callCtx, float32(userId)).UserUpdatePayload(body).Execute()
res, httpResp, err := client.UsersAPI.UserUserIdPut(callCtx, float32(userId)).UserUpdatePayload(body).Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("UserUserIdPut failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Expand All @@ -132,14 +122,11 @@ func UsersQuotaHandler() server.ToolHandlerFunc {
return nil, err
}
client := newAPIClientWithKey(apiKeyFromContext(callCtx))
res, _, err := client.UsersAPI.UserUserIdQuotaGet(callCtx, float32(userId)).Execute()
res, httpResp, err := client.UsersAPI.UserUserIdQuotaGet(callCtx, float32(userId)).Execute()
b, err := marshalResult(res, httpResp, err)
if err != nil {
return apiToolError("UserUserIdQuotaGet failed", err)
}
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(string(b)), nil
}
}
Loading
Loading