diff --git a/README.md b/README.md index c662744..e3b4d50 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,10 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | Variable | Default | Description | | ------------------- | ------------------------ | ------------------------------------------------------- | | `FMSG_DATA_DIR` | *(required)* | Path where message data files are stored, e.g. `/var/lib/fmsgd/` | -| `FMSG_API_JWT_SECRET` | *(required)* | HMAC secret used to validate JWT tokens. Prefix with `base64:` to supply a base64-encoded key (e.g. `base64:c2VjcmV0`); otherwise the raw string is used. | +| `FMSG_JWT_JWKS_URL` | *(prod)* | URL of the IdP's JWKS endpoint (e.g. `https://idp.fmsg.io/.well-known/jwks.json`). When set, the API verifies EdDSA tokens issued by the IdP. Public keys are fetched and cached, refreshed and looked up by the token's `kid` header. | +| `FMSG_JWT_ISSUER` | *(prod, required with JWKS)* | Expected `iss` claim value (e.g. `https://idp.fmsg.io`). Tokens with a different issuer are rejected. | +| `FMSG_JWT_AUDIENCE` | *(optional)* | When set, tokens must include this value in their `aud` claim. | +| `FMSG_API_JWT_SECRET` | *(dev)* | HMAC secret for HS256 token verification. Used only in dev mode (when `FMSG_JWT_JWKS_URL` is unset). Prefix with `base64:` to supply a base64-encoded key. Either this or `FMSG_JWT_JWKS_URL` must be set. | | `FMSG_TLS_CERT` | *(optional)* | Path to the TLS certificate file (e.g. `/etc/letsencrypt/live/example.com/fullchain.pem`). When set with `FMSG_TLS_KEY`, enables HTTPS on port 443. | | `FMSG_TLS_KEY` | *(optional)* | Path to the TLS private key file (e.g. `/etc/letsencrypt/live/example.com/privkey.pem`). Must be set together with `FMSG_TLS_CERT`. | | `FMSG_API_PORT` | `8000` | TCP port for plain HTTP mode (ignored when TLS is enabled) | @@ -26,6 +29,42 @@ Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, A `.env` file placed in the working directory is loaded automatically at startup (values in the environment take precedence). +## Authentication + +All `/fmsg/*` routes require an `Authorization: Bearer ` header. The API +operates in one of two verification modes, selected automatically at startup: + +### EdDSA (production) + +Active when `FMSG_JWT_JWKS_URL` is set. Tokens are expected to be issued by the +fmsg IdP and signed with Ed25519. The JWKS endpoint is polled on a schedule; +the IdP can rotate keys by adding a new JWK with a fresh `kid`. + +Required token header: `alg: EdDSA`, `kid: `, `typ: JWT`. + +Required claims: + +| Claim | Description | +| ----- | ----------- | +| `iss` | Must equal `FMSG_JWT_ISSUER`. | +| `sub` | User address in `@user@domain` form. | +| `iat` | Issued-at timestamp (Unix seconds). | +| `nbf` | Not-before timestamp. | +| `exp` | Expiry timestamp (must be in the future, ±10 s leeway). | +| `jti` | Unique token ID. Used for in-process replay prevention until `exp`. | +| `aud` | Optional; required only when `FMSG_JWT_AUDIENCE` is set. | + +A 10-second clock-skew leeway is applied to `iat`/`nbf`/`exp` validation. +Replay prevention is in-process and does not coordinate across multiple API +instances; deploy as a single instance or replace the cache before scaling +horizontally. + +### HMAC (development) + +Active when `FMSG_JWT_JWKS_URL` is unset. Tokens must be HS256-signed with the +shared secret in `FMSG_API_JWT_SECRET`. Required claims are `sub` and `exp`; +`iat`/`nbf` are honoured when present. No replay prevention is applied. + ## Building Requires **Go 1.25** or newer. @@ -50,7 +89,8 @@ Set `FMSG_TLS_CERT` and `FMSG_TLS_KEY` to enable HTTPS on port `443`. ```bash export FMSG_DATA_DIR=/opt/fmsg/data -export FMSG_API_JWT_SECRET=changeme +export FMSG_JWT_JWKS_URL=https://idp.fmsg.io/.well-known/jwks.json +export FMSG_JWT_ISSUER=https://idp.fmsg.io export FMSG_TLS_CERT=/etc/letsencrypt/live/example.com/fullchain.pem export FMSG_TLS_KEY=/etc/letsencrypt/live/example.com/privkey.pem export PGHOST=localhost diff --git a/src/go.mod b/src/go.mod index 69e0407..9534b24 100644 --- a/src/go.mod +++ b/src/go.mod @@ -3,15 +3,16 @@ module github.com/markmnl/fmsg-webapi go 1.25.0 require ( - github.com/appleboy/gin-jwt/v2 v2.10.3 + github.com/MicahParks/keyfunc/v3 v3.8.0 github.com/gin-gonic/gin v1.12.0 - github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/jackc/pgx/v5 v5.8.0 github.com/joho/godotenv v1.5.1 golang.org/x/time v0.15.0 ) require ( + github.com/MicahParks/jwkset v0.11.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect @@ -37,7 +38,6 @@ require ( github.com/quic-go/quic-go v0.59.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect - github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/crypto v0.48.0 // indirect diff --git a/src/go.sum b/src/go.sum index edc7d4d..7daea34 100644 --- a/src/go.sum +++ b/src/go.sum @@ -1,7 +1,7 @@ -github.com/appleboy/gin-jwt/v2 v2.10.3 h1:KNcPC+XPRNpuoBh+j+rgs5bQxN+SwG/0tHbIqpRoBGc= -github.com/appleboy/gin-jwt/v2 v2.10.3/go.mod h1:LDUaQ8mF2W6LyXIbd5wqlV2SFebuyYs4RDwqMNgpsp8= -github.com/appleboy/gofight/v2 v2.1.2 h1:VOy3jow4vIK8BRQJoC/I9muxyYlJ2yb9ht2hZoS3rf4= -github.com/appleboy/gofight/v2 v2.1.2/go.mod h1:frW+U1QZEdDgixycTj4CygQ48yLTUhplt43+Wczp3rw= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.8.0 h1:Hx2dgIjAXGk9slakM6rV9BOeaWDPEXXZ4Us8guNBfds= +github.com/MicahParks/keyfunc/v3 v3.8.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= @@ -31,8 +31,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= -github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= -github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -79,18 +79,10 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= -github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= -github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= -github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/src/main.go b/src/main.go index 4aa99ee..e0b4d9f 100644 --- a/src/main.go +++ b/src/main.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/MicahParks/keyfunc/v3" "github.com/gin-gonic/gin" "github.com/joho/godotenv" @@ -26,8 +27,13 @@ func main() { // Required configuration. dataDir := mustEnv("FMSG_DATA_DIR") - jwtSecret := mustEnv("FMSG_API_JWT_SECRET") - jwtKey := parseSecret(jwtSecret) + + // JWT configuration. Mode is selected automatically: + // * EdDSA (prod) when FMSG_JWT_JWKS_URL is set. + // * HMAC (dev) otherwise, using FMSG_API_JWT_SECRET. + jwksURL := os.Getenv("FMSG_JWT_JWKS_URL") + jwtIssuer := os.Getenv("FMSG_JWT_ISSUER") + jwtAudience := os.Getenv("FMSG_JWT_AUDIENCE") // TLS configuration (optional — omit both to run plain HTTP). tlsCert := os.Getenv("FMSG_TLS_CERT") @@ -55,7 +61,11 @@ func main() { log.Println("connected to PostgreSQL") // Initialise JWT middleware. - jwtMiddleware, err := middleware.SetupJWT(jwtKey, idURL) + jwtCfg, err := buildJWTConfig(ctx, jwksURL, jwtIssuer, jwtAudience, idURL) + if err != nil { + log.Fatalf("failed to configure JWT: %v", err) + } + jwtMiddleware, err := middleware.New(jwtCfg) if err != nil { log.Fatalf("failed to initialise JWT middleware: %v", err) } @@ -72,7 +82,7 @@ func main() { // Register routes under /fmsg, all protected by JWT. fmsg := router.Group("/fmsg") - fmsg.Use(jwtMiddleware.MiddlewareFunc()) + fmsg.Use(jwtMiddleware) { fmsg.GET("/wait", msgHandler.Wait) fmsg.GET("", msgHandler.List) @@ -145,6 +155,40 @@ func envOrDefaultInt(key string, defaultValue int) int { return defaultValue } +// buildJWTConfig assembles a middleware.Config from environment-derived +// inputs, picking EdDSA (prod) when a JWKS URL is supplied and falling back +// to HMAC (dev) otherwise. +func buildJWTConfig(ctx context.Context, jwksURL, issuer, audience, idURL string) (middleware.Config, error) { + cfg := middleware.Config{ + Issuer: issuer, + Audience: audience, + IDURL: idURL, + } + + if jwksURL != "" { + if issuer == "" { + return cfg, errors.New("FMSG_JWT_ISSUER is required when FMSG_JWT_JWKS_URL is set") + } + k, err := keyfunc.NewDefaultCtx(ctx, []string{jwksURL}) + if err != nil { + return cfg, err + } + cfg.Mode = middleware.ModeEdDSA + cfg.JWKS = k.Keyfunc + log.Printf("JWT mode: EdDSA (issuer=%s, jwks=%s)", issuer, jwksURL) + return cfg, nil + } + + secret := os.Getenv("FMSG_API_JWT_SECRET") + if secret == "" { + return cfg, errors.New("either FMSG_JWT_JWKS_URL (prod) or FMSG_API_JWT_SECRET (dev) must be set") + } + cfg.Mode = middleware.ModeHMAC + cfg.HMACKey = parseSecret(secret) + log.Println("JWT mode: HMAC (development)") + return cfg, nil +} + // parseSecret returns the HMAC key bytes for the given secret string. // If s begins with "base64:" the remainder is base64-decoded; otherwise the // raw string bytes are used. diff --git a/src/middleware/jti_cache.go b/src/middleware/jti_cache.go new file mode 100644 index 0000000..cbe2480 --- /dev/null +++ b/src/middleware/jti_cache.go @@ -0,0 +1,93 @@ +package middleware + +import ( + "sync" + "time" +) + +// jtiCacheMaxEntries bounds memory usage of the in-process replay cache. +// When exceeded, expired entries are swept first; if still over the limit, +// new entries are dropped (the request is still rejected only on a true +// duplicate, never on overflow). +const jtiCacheMaxEntries = 100_000 + +// jtiCache tracks JWT IDs that have already been seen, until their +// corresponding token expiry, to prevent replay attacks. +// +// The cache lives in-process; it does not coordinate across multiple API +// instances. For a horizontally-scaled deployment, replace with a shared +// store (e.g. Postgres or Redis). +type jtiCache struct { + mu sync.Mutex + entries map[string]time.Time + stop chan struct{} +} + +// newJTICache returns a cache with a background sweeper running until Close. +func newJTICache() *jtiCache { + c := &jtiCache{ + entries: make(map[string]time.Time), + stop: make(chan struct{}), + } + go c.sweepLoop(time.Minute) + return c +} + +// Seen atomically checks whether jti has been recorded with an unexpired +// entry; if not, records it with the given expiry. Returns true if the +// jti was already present (i.e. this is a replay). +// +// Empty jti strings are never considered seen (caller decides policy). +func (c *jtiCache) Seen(jti string, exp time.Time) bool { + if jti == "" { + return false + } + now := time.Now() + c.mu.Lock() + defer c.mu.Unlock() + if existing, ok := c.entries[jti]; ok && existing.After(now) { + return true + } + if len(c.entries) >= jtiCacheMaxEntries { + c.sweepLocked(now) + if len(c.entries) >= jtiCacheMaxEntries { + // Cache full of unexpired entries; refuse to grow but do not + // falsely flag the token as a replay. + return false + } + } + c.entries[jti] = exp + return false +} + +// Close stops the background sweeper. +func (c *jtiCache) Close() { + select { + case <-c.stop: + default: + close(c.stop) + } +} + +func (c *jtiCache) sweepLoop(interval time.Duration) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-c.stop: + return + case now := <-t.C: + c.mu.Lock() + c.sweepLocked(now) + c.mu.Unlock() + } + } +} + +func (c *jtiCache) sweepLocked(now time.Time) { + for k, exp := range c.entries { + if !exp.After(now) { + delete(c.entries, k) + } + } +} diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index 625328a..2a6e466 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -2,115 +2,226 @@ package middleware import ( + "errors" "fmt" "log" "net/http" "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" - jwtv4 "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" ) +// IdentityKey is the Gin context key under which the authenticated user +// address is stored. const IdentityKey = "sub" -// clockSkew is the leeway applied to iat/nbf/exp validation to tolerate minor -// clock differences between services (e.g. containers on the same host). -const clockSkew = 10 * time.Second +// DefaultClockSkew is the leeway applied to iat/nbf/exp validation to tolerate +// minor clock differences between services. +const DefaultClockSkew = 10 * time.Second -// identityClaims is the payload stored in the JWT. -type identityClaims struct { - Addr string +// Mode selects the JWT verification strategy. +type Mode int + +const ( + // ModeHMAC verifies HS256 tokens with a shared symmetric secret. + // Intended for development and testing. + ModeHMAC Mode = iota + // ModeEdDSA verifies EdDSA (Ed25519) tokens whose public keys are + // served by an external IdP via JWKS. + ModeEdDSA +) + +// Config configures the JWT middleware. +type Config struct { + // Mode selects HMAC (dev) or EdDSA (prod) verification. + Mode Mode + + // HMACKey is the symmetric secret bytes (required when Mode == ModeHMAC). + HMACKey []byte + + // JWKS resolves Ed25519 public keys (typically by token header `kid`). + // Required when Mode == ModeEdDSA. + JWKS jwt.Keyfunc + + // Issuer, when non-empty, is required to match the token `iss` claim. + // Mandatory in EdDSA mode. + Issuer string + + // Audience, when non-empty, is required to be present in the token + // `aud` claim. Optional. + Audience string + + // IDURL is the base URL of the fmsgid identity service. + IDURL string + + // ClockSkew is the leeway applied to time-based claim validation. + // Defaults to DefaultClockSkew when zero. + ClockSkew time.Duration } -// SetupJWT creates and returns a configured GinJWTMiddleware. -// key is the HMAC secret bytes used to validate tokens. -// idURL is the base URL of the fmsgid service used to validate user addresses. -func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { - // Set the global TimeFunc used by golang-jwt/v4 when validating iat/nbf/exp - // inside MapClaims.Valid(). gin-jwt's own TimeFunc field does not affect this - // path; only the package-level variable does. - jwtv4.TimeFunc = func() time.Time { return time.Now().Add(clockSkew) } - - mw, err := jwt.New(&jwt.GinJWTMiddleware{ - Realm: "fmsg", - Key: key, - Timeout: 24 * time.Hour, - MaxRefresh: 24 * time.Hour, - IdentityKey: IdentityKey, - - // PayloadFunc stores the user address from the login credentials into JWT claims. - PayloadFunc: func(data interface{}) jwt.MapClaims { - if v, ok := data.(*identityClaims); ok { - return jwt.MapClaims{IdentityKey: v.Addr} - } - return jwt.MapClaims{} - }, - - // IdentityHandler extracts the address from JWT claims and puts it in Gin context. - IdentityHandler: func(c *gin.Context) interface{} { - claims := jwt.ExtractClaims(c) - addr, _ := claims[IdentityKey].(string) - return &identityClaims{Addr: addr} - }, - - // Authorizator validates the extracted identity and checks with fmsgid. - Authorizator: func(data interface{}, c *gin.Context) bool { - v, ok := data.(*identityClaims) - if !ok || v == nil { - return false +// New constructs the JWT verification middleware. +// +// The returned handler: +// - extracts a Bearer token from the Authorization header, +// - parses & verifies the signature according to cfg.Mode, +// - validates iss/aud/exp/nbf claims, +// - rejects replays (EdDSA mode only) by tracking jti in-process, +// - extracts sub as the user address and validates its shape, +// - calls fmsgid to confirm the user is known and accepting messages, +// - on success stores the address in the Gin context under IdentityKey. +// +// On failure the response is 400/401/403/503 with a JSON `{"error": "..."}` body. +func New(cfg Config) (gin.HandlerFunc, error) { + if cfg.ClockSkew == 0 { + cfg.ClockSkew = DefaultClockSkew + } + + var ( + validMethods []string + keyFunc jwt.Keyfunc + ) + + switch cfg.Mode { + case ModeHMAC: + if len(cfg.HMACKey) == 0 { + return nil, errors.New("middleware: HMAC mode requires a non-empty HMACKey") + } + validMethods = []string{jwt.SigningMethodHS256.Alg()} + key := cfg.HMACKey + keyFunc = func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %s", t.Method.Alg()) } - addr := v.Addr - if !isValidAddr(addr) { - log.Printf("auth rejected: ip=%s reason=invalid_addr", c.ClientIP()) - return false + return key, nil + } + case ModeEdDSA: + if cfg.JWKS == nil { + return nil, errors.New("middleware: EdDSA mode requires a JWKS keyfunc") + } + if cfg.Issuer == "" { + return nil, errors.New("middleware: EdDSA mode requires an Issuer") + } + validMethods = []string{jwt.SigningMethodEdDSA.Alg()} + jwks := cfg.JWKS + keyFunc = func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("unexpected signing method: %s", t.Method.Alg()) } - // Store the validated identity in context for downstream handlers. - c.Set(IdentityKey, addr) - - // Validate with fmsgid service. - code, accepting, err := checkFmsgID(idURL, addr) - if err != nil { - log.Printf("fmsgid check error for %s: %v", addr, err) - c.Set("auth_error_code", http.StatusServiceUnavailable) - c.Set("auth_error_msg", "identity service unavailable") - return false + return jwks(t) + } + default: + return nil, fmt.Errorf("middleware: unknown JWT mode %d", cfg.Mode) + } + + parserOpts := []jwt.ParserOption{ + jwt.WithValidMethods(validMethods), + jwt.WithLeeway(cfg.ClockSkew), + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + } + if cfg.Issuer != "" { + parserOpts = append(parserOpts, jwt.WithIssuer(cfg.Issuer)) + } + if cfg.Audience != "" { + parserOpts = append(parserOpts, jwt.WithAudience(cfg.Audience)) + } + parser := jwt.NewParser(parserOpts...) + + var replay *jtiCache + if cfg.Mode == ModeEdDSA { + replay = newJTICache() + } + + idURL := cfg.IDURL + + return func(c *gin.Context) { + tokenStr, err := extractBearer(c.GetHeader("Authorization")) + if err != nil { + respondAuth(c, http.StatusUnauthorized, "missing or malformed Authorization header") + return + } + + claims := jwt.MapClaims{} + if _, err := parser.ParseWithClaims(tokenStr, claims, keyFunc); err != nil { + log.Printf("auth rejected: ip=%s reason=parse_error err=%v", c.ClientIP(), err) + respondAuth(c, http.StatusUnauthorized, "invalid token") + return + } + + addr, _ := claims["sub"].(string) + if !isValidAddr(addr) { + log.Printf("auth rejected: ip=%s reason=invalid_addr sub=%q", c.ClientIP(), addr) + respondAuth(c, http.StatusUnauthorized, "invalid identity") + return + } + + if replay != nil { + jti, _ := claims["jti"].(string) + if jti == "" { + log.Printf("auth rejected: ip=%s addr=%s reason=missing_jti", c.ClientIP(), addr) + respondAuth(c, http.StatusUnauthorized, "invalid token") + return } - if code == http.StatusNotFound { - log.Printf("auth rejected: ip=%s addr=%s reason=not_found", c.ClientIP(), addr) - c.Set("auth_error_code", http.StatusBadRequest) - c.Set("auth_error_msg", fmt.Sprintf("User %s not found", addr)) - return false + expTime, err := claims.GetExpirationTime() + if err != nil || expTime == nil { + respondAuth(c, http.StatusUnauthorized, "invalid token") + return } - if code == http.StatusOK && !accepting { - log.Printf("auth rejected: ip=%s addr=%s reason=not_accepting", c.ClientIP(), addr) - c.Set("auth_error_code", http.StatusForbidden) - c.Set("auth_error_msg", fmt.Sprintf("User %s not authorised to send new messages", addr)) - return false + if replay.Seen(jti, expTime.Time) { + log.Printf("auth rejected: ip=%s addr=%s reason=jti_replay jti=%s", c.ClientIP(), addr, jti) + respondAuth(c, http.StatusUnauthorized, "token already used") + return } - return true - }, + } - // Unauthorized responds when JWT validation or authorization fails. - Unauthorized: func(c *gin.Context, code int, message string) { - if errCode, exists := c.Get("auth_error_code"); exists { - code = errCode.(int) - } - if errMsg, exists := c.Get("auth_error_msg"); exists { - message = errMsg.(string) - } - log.Printf("auth failure: ip=%s code=%d message=%s", c.ClientIP(), code, message) - c.JSON(code, gin.H{"error": message}) - }, - - TokenLookup: "header: Authorization", - TokenHeadName: "Bearer", - // TimeFunc is used by gin-jwt for orig_iat and expiry arithmetic. - // Must be kept consistent with the jwtv4.TimeFunc set above. - TimeFunc: func() time.Time { return time.Now().Add(clockSkew) }, - }) - return mw, err + code, accepting, err := checkFmsgID(idURL, addr) + if err != nil { + log.Printf("fmsgid check error for %s: %v", addr, err) + respondAuth(c, http.StatusServiceUnavailable, "identity service unavailable") + return + } + switch { + case code == http.StatusNotFound: + log.Printf("auth rejected: ip=%s addr=%s reason=not_found", c.ClientIP(), addr) + respondAuth(c, http.StatusBadRequest, fmt.Sprintf("User %s not found", addr)) + return + case code == http.StatusOK && !accepting: + log.Printf("auth rejected: ip=%s addr=%s reason=not_accepting", c.ClientIP(), addr) + respondAuth(c, http.StatusForbidden, fmt.Sprintf("User %s not authorised to send new messages", addr)) + return + case code != http.StatusOK: + log.Printf("auth rejected: ip=%s addr=%s reason=fmsgid_status=%d", c.ClientIP(), addr, code) + respondAuth(c, http.StatusServiceUnavailable, "identity service unavailable") + return + } + + c.Set(IdentityKey, addr) + c.Next() + }, nil +} + +// respondAuth aborts the request with a JSON error body. +func respondAuth(c *gin.Context, code int, message string) { + log.Printf("auth failure: ip=%s code=%d message=%s", c.ClientIP(), code, message) + c.AbortWithStatusJSON(code, gin.H{"error": message}) +} + +// extractBearer returns the token portion of an Authorization header value. +func extractBearer(header string) (string, error) { + if header == "" { + return "", errors.New("empty header") + } + const prefix = "Bearer " + if !strings.HasPrefix(header, prefix) { + return "", errors.New("missing Bearer prefix") + } + tok := strings.TrimSpace(header[len(prefix):]) + if tok == "" { + return "", errors.New("empty token") + } + return tok, nil } // GetIdentity retrieves the authenticated user address from the Gin context. @@ -152,7 +263,6 @@ func checkFmsgID(idURL, addr string) (int, bool, error) { return resp.StatusCode, false, nil } - // Parse acceptingNew from the JSON response. var result struct { AcceptingNew bool `json:"acceptingNew"` } diff --git a/src/middleware/jwt_test.go b/src/middleware/jwt_test.go index 1b210ca..549b1a9 100644 --- a/src/middleware/jwt_test.go +++ b/src/middleware/jwt_test.go @@ -1,12 +1,22 @@ package middleware import ( + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "time" - jwtv4 "github.com/golang-jwt/jwt/v4" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" ) +func init() { + gin.SetMode(gin.TestMode) +} + func TestIsValidAddr(t *testing.T) { tests := []struct { addr string @@ -15,11 +25,11 @@ func TestIsValidAddr(t *testing.T) { {"@alice@example.com", true}, {"@a@b", true}, {"@user@domain.org", true}, - {"alice@example.com", false}, // missing leading @ - {"@nodomain", false}, // no second @ - {"@", false}, // too short - {"ab", false}, // too short, no @ - {"", false}, // empty + {"alice@example.com", false}, + {"@nodomain", false}, + {"@", false}, + {"ab", false}, + {"", false}, } for _, tt := range tests { t.Run(tt.addr, func(t *testing.T) { @@ -30,60 +40,309 @@ func TestIsValidAddr(t *testing.T) { } } -func TestGetIdentityEmpty(t *testing.T) { - // GetIdentity with no context key set should return "". - // We can't easily construct a gin.Context here without more setup, - // so this is a smoke-level check that the function signature is correct. - _ = GetIdentity +// fakeJWKS returns a jwt.Keyfunc that yields a fixed Ed25519 public key for +// a single known kid. +func fakeJWKS(kid string, pub ed25519.PublicKey) jwt.Keyfunc { + return func(t *jwt.Token) (interface{}, error) { + k, _ := t.Header["kid"].(string) + if k != kid { + return nil, jwt.ErrTokenSignatureInvalid + } + return pub, nil + } } -// TestIATClockSkew verifies that SetupJWT installs a global TimeFunc that -// allows tokens whose iat is up to clockSkew seconds in the future, and -// rejects tokens whose iat exceeds that window. -func TestIATClockSkew(t *testing.T) { - const secret = "test-secret-for-iat-skew" +// fmsgIDServer returns an httptest server emulating fmsgid responses. +func fmsgIDServer(t *testing.T, status int, accepting bool) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if status != http.StatusOK { + w.WriteHeader(status) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]bool{"acceptingNew": accepting}) + })) +} - // SetupJWT sets jwtv4.TimeFunc as a side-effect. - if _, err := SetupJWT([]byte(secret), "http://127.0.0.1:0"); err != nil { - t.Fatalf("SetupJWT: %v", err) +// runMiddleware executes the middleware against a synthetic request bearing +// the given token, returning the recorded response. +func runMiddleware(t *testing.T, mw gin.HandlerFunc, token string) *httptest.ResponseRecorder { + t.Helper() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/fmsg", nil) + if token != "" { + c.Request.Header.Set("Authorization", "Bearer "+token) } + called := false + c.Set("__test_next__", &called) + mw(c) + return w +} - sign := func(iatOffset time.Duration) string { - claims := jwtv4.MapClaims{ - "sub": "@alice@example.com", - "iat": time.Now().Add(iatOffset).Unix(), - "exp": time.Now().Add(24 * time.Hour).Unix(), - } - tok := jwtv4.NewWithClaims(jwtv4.SigningMethodHS256, claims) - s, err := tok.SignedString([]byte(secret)) - if err != nil { - t.Fatalf("sign token: %v", err) - } - return s +func signEdDSA(t *testing.T, priv ed25519.PrivateKey, kid string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + tok.Header["kid"] = kid + s, err := tok.SignedString(priv) + if err != nil { + t.Fatalf("sign: %v", err) } + return s +} - keyFunc := func(t *jwtv4.Token) (interface{}, error) { return []byte(secret), nil } +func signHS256(t *testing.T, secret []byte, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + s, err := tok.SignedString(secret) + if err != nil { + t.Fatalf("sign: %v", err) + } + return s +} - tests := []struct { - name string - iatOffset time.Duration - wantErr bool - }{ - {"iat now+1s accepted", 1 * time.Second, false}, - {"iat now+10s accepted", clockSkew, false}, - {"iat now+11s rejected", clockSkew + time.Second, true}, +func TestHMACMode_Happy(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + + secret := []byte("dev-secret") + mw, err := New(Config{Mode: ModeHMAC, HMACKey: secret, IDURL: srv.URL}) + if err != nil { + t.Fatalf("New: %v", err) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenString := sign(tt.iatOffset) - _, err := jwtv4.Parse(tokenString, keyFunc) - if tt.wantErr && err == nil { - t.Error("expected validation error but token was accepted") - } - if !tt.wantErr && err != nil { - t.Errorf("token should be accepted but got error: %v", err) - } - }) + tok := signHS256(t, secret, jwt.MapClaims{ + "sub": "@alice@example.com", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + w := runMiddleware(t, mw, tok) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestHMACMode_ClockSkewLeeway(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + secret := []byte("dev-secret") + mw, err := New(Config{Mode: ModeHMAC, HMACKey: secret, IDURL: srv.URL}) + if err != nil { + t.Fatalf("New: %v", err) + } + + // iat/nbf within leeway is accepted. + now := time.Now() + tok := signHS256(t, secret, jwt.MapClaims{ + "sub": "@alice@example.com", + "iat": now.Add(DefaultClockSkew - time.Second).Unix(), + "nbf": now.Add(DefaultClockSkew - time.Second).Unix(), + "exp": now.Add(time.Hour).Unix(), + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { + t.Fatalf("within-skew token should be accepted, got %d", w.Code) + } + + // Beyond leeway is rejected. + tok = signHS256(t, secret, jwt.MapClaims{ + "sub": "@alice@example.com", + "nbf": now.Add(DefaultClockSkew + 5*time.Second).Unix(), + "exp": now.Add(time.Hour).Unix(), + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + t.Fatalf("out-of-skew token should be rejected, got %d", w.Code) + } +} + +func TestHMACMode_MissingHeader(t *testing.T) { + mw, err := New(Config{Mode: ModeHMAC, HMACKey: []byte("k"), IDURL: "http://127.0.0.1:0"}) + if err != nil { + t.Fatal(err) + } + if w := runMiddleware(t, mw, ""); w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func newEdDSAFixture(t *testing.T) (priv ed25519.PrivateKey, jwks jwt.Keyfunc) { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("genkey: %v", err) + } + return priv, fakeJWKS("prod-1", pub) +} + +func eddsaConfig(idURL string, jwks jwt.Keyfunc) Config { + return Config{ + Mode: ModeEdDSA, + JWKS: jwks, + Issuer: "https://idp.fmsg.io", + IDURL: idURL, + } +} + +func TestEdDSAMode_Happy(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + priv, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := signEdDSA(t, priv, "prod-1", jwt.MapClaims{ + "iss": "https://idp.fmsg.io", + "sub": "@alice@example.com", + "iat": time.Now().Unix(), + "nbf": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "11111111-1111-1111-1111-111111111111", + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestEdDSAMode_WrongIssuer(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + priv, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatal(err) + } + tok := signEdDSA(t, priv, "prod-1", jwt.MapClaims{ + "iss": "https://evil.example.com", + "sub": "@alice@example.com", + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "a", + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestEdDSAMode_UnknownKID(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + priv, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatal(err) + } + tok := signEdDSA(t, priv, "rotated-key", jwt.MapClaims{ + "iss": "https://idp.fmsg.io", + "sub": "@alice@example.com", + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "a", + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestEdDSAMode_AlgDowngrade(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + _, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatal(err) + } + // Sign with HS256 — must be rejected by an EdDSA-only middleware. + tok := signHS256(t, []byte("anything"), jwt.MapClaims{ + "iss": "https://idp.fmsg.io", + "sub": "@alice@example.com", + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "a", + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestEdDSAMode_Expired(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + priv, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatal(err) + } + tok := signEdDSA(t, priv, "prod-1", jwt.MapClaims{ + "iss": "https://idp.fmsg.io", + "sub": "@alice@example.com", + "exp": time.Now().Add(-time.Hour).Unix(), + "jti": "a", + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestEdDSAMode_Replay(t *testing.T) { + srv := fmsgIDServer(t, http.StatusOK, true) + defer srv.Close() + priv, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatal(err) + } + claims := jwt.MapClaims{ + "iss": "https://idp.fmsg.io", + "sub": "@alice@example.com", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "replay-me", + } + tok := signEdDSA(t, priv, "prod-1", claims) + + if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { + t.Fatalf("first call expected 200, got %d", w.Code) + } + if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { + t.Fatalf("replay expected 401, got %d", w.Code) + } +} + +func TestEdDSAMode_FmsgIDUnavailable(t *testing.T) { + srv := fmsgIDServer(t, http.StatusInternalServerError, false) + defer srv.Close() + priv, jwks := newEdDSAFixture(t) + mw, err := New(eddsaConfig(srv.URL, jwks)) + if err != nil { + t.Fatal(err) + } + tok := signEdDSA(t, priv, "prod-1", jwt.MapClaims{ + "iss": "https://idp.fmsg.io", + "sub": "@alice@example.com", + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "x", + }) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusServiceUnavailable { + t.Fatalf("expected 503, got %d", w.Code) + } +} + +func TestJTICache_SeenAndExpiry(t *testing.T) { + c := newJTICache() + defer c.Close() + exp := time.Now().Add(time.Hour) + if c.Seen("a", exp) { + t.Fatal("first Seen should be false") + } + if !c.Seen("a", exp) { + t.Fatal("second Seen should be true") + } + // Expired entry should not count as seen. + if c.Seen("b", time.Now().Add(-time.Second)) { + t.Fatal("expired entry: first Seen should be false") + } + // And subsequently the cached entry, having expired in the past, should + // be replaceable. + if c.Seen("b", time.Now().Add(-time.Second)) { + t.Fatal("expired entry: should not flag as replay") } }