diff --git a/README.md b/README.md index a3c83a3..730ce16 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | `FMSG_API_MAX_DATA_SIZE`| `10` | Maximum message data size in megabytes | | `FMSG_API_MAX_ATTACH_SIZE`| `10` | Maximum attachment file size in megabytes | | `FMSG_API_MAX_MSG_SIZE`| `20` | Maximum total message size (data + attachments) in megabytes | +| `FMSG_API_SHORT_TEXT_SIZE`| `768` | Maximum bytes of message body returned inline as `short_text` for `text/*` UTF-8 messages | | `FMSG_CORS_ORIGINS` | *(optional)* | Comma-separated list of browser origins allowed via CORS, e.g. `https://example.com,https://www.example.com`. Use `*` to allow any origin. When unset, no CORS headers are emitted (server-to-server callers are unaffected). | Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, @@ -252,7 +253,7 @@ When `add_to` recipients are provided, `add_to_from` is automatically populated | Status | Condition | | ------ | --------- | -| `400` | Missing/invalid fields or empty `to` | +| `400` | Missing/invalid fields, empty `to`, `topic` set together with `pid`, or `add_to`/`add_to_from` set without `pid` | | `403` | `from` does not match authenticated user | ### GET `/fmsg/:id` @@ -279,10 +280,18 @@ Retrieves a single message by ID. The authenticated user must be a participant "topic": "Hello", "type": "text/plain", "size": 12, + "short_text": "hello world", "attachments": [] } ``` +The `short_text` field is included only when the message `type` is `text/*` and +the stored body is valid UTF-8. When `FMSG_API_SHORT_TEXT_SIZE` is greater than +`0`, it contains up to that many bytes (default 768) of the body, truncated on +a UTF-8 rune boundary, so UIs can render a preview without a separate +`GET /fmsg/:id/data` round-trip. Set `FMSG_API_SHORT_TEXT_SIZE=0` to disable +`short_text` generation. + **Errors:** | Status | Condition | @@ -298,7 +307,7 @@ Updates a draft message. Only the owner (`from`) may update, and the message mus | Status | Condition | | ------ | --------- | -| `400` | Invalid fields | +| `400` | Invalid fields, `topic` set together with `pid`, or `add_to`/`add_to_from` set without `pid` | | `403` | Not the owner, or message already sent | | `404` | Message not found | @@ -349,7 +358,7 @@ New addresses must be distinct among themselves (case-insensitive). | Status | Condition | | ------ | --------- | -| `400` | Empty `add_to` or duplicate addresses in request | +| `400` | Empty `add_to`, duplicate addresses, or target message has no `pid` | | `403` | Authenticated user is not an existing participant (sender or `to` recipient) | | `404` | Message not found | diff --git a/src/handlers/attachments.go b/src/handlers/attachments.go index f662215..c92b3da 100644 --- a/src/handlers/attachments.go +++ b/src/handlers/attachments.go @@ -30,7 +30,7 @@ func NewAttachmentHandler(database *db.DB, dataDir string, maxAttachSize, maxMsg return &AttachmentHandler{DB: database, DataDir: dataDir, MaxAttachSize: maxAttachSize, MaxMsgSize: maxMsgSize} } -// Upload handles POST /api/v1/messages/:id/attachments. +// Upload handles POST /fmsg/:id/attachments. func (h *AttachmentHandler) Upload(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -166,7 +166,7 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { c.JSON(http.StatusCreated, gin.H{"filename": intendedFilename, "size": written}) } -// Download handles GET /api/v1/messages/:id/attachments/:filename. +// Download handles GET /fmsg/:id/attachments/:filename. func (h *AttachmentHandler) Download(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -239,7 +239,7 @@ func (h *AttachmentHandler) Download(c *gin.Context) { c.FileAttachment(cleanPath, filename) } -// DeleteAttachment handles DELETE /api/v1/messages/:id/attachments/:filename. +// DeleteAttachment handles DELETE /fmsg/:id/attachments/:filename. func (h *AttachmentHandler) DeleteAttachment(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) diff --git a/src/handlers/messages.go b/src/handlers/messages.go index f55f4ed..897b923 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "log" "mime" "net/http" @@ -13,6 +14,7 @@ import ( "strconv" "strings" "time" + "unicode/utf8" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5" @@ -24,15 +26,16 @@ import ( // MessageHandler holds dependencies for message routes. type MessageHandler struct { - DB *db.DB - DataDir string - MaxDataSize int64 - MaxMsgSize int64 + DB *db.DB + DataDir string + MaxDataSize int64 + MaxMsgSize int64 + ShortTextSize int } // NewMessageHandler creates a MessageHandler. -func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64) *MessageHandler { - return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize} +func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64, shortTextSize int) *MessageHandler { + return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize, ShortTextSize: shortTextSize} } // messageListItem is the JSON shape for each message in the list response. @@ -53,6 +56,7 @@ type messageListItem struct { Topic string `json:"topic"` Type string `json:"type"` Size int `json:"size"` + ShortText string `json:"short_text,omitempty"` Attachments []models.Attachment `json:"attachments"` } @@ -173,7 +177,7 @@ func (h *MessageHandler) latestMessageIDForRecipient(ctx context.Context, identi return latestID, nil } -// List handles GET /api/v1/messages — lists messages where the authenticated user is a recipient. +// List handles GET /fmsg — lists messages where the authenticated user is a recipient. func (h *MessageHandler) List(c *gin.Context) { identity := middleware.GetIdentity(c) @@ -185,7 +189,7 @@ func (h *MessageHandler) List(c *gin.Context) { ctx := c.Request.Context() rows, err := h.DB.Pool.Query(ctx, - `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size + `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size, m.filepath FROM msg m WHERE EXISTS (SELECT 1 FROM msg_to mt WHERE mt.msg_id = m.id AND mt.addr = $1) OR EXISTS (SELECT 1 FROM msg_add_to mat WHERE mat.msg_id = m.id AND mat.addr = $1) @@ -204,12 +208,14 @@ func (h *MessageHandler) List(c *gin.Context) { var msgIDs []int64 for rows.Next() { var m messageListItem - if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size); err != nil { + var dataPath string + if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size, &dataPath); err != nil { log.Printf("list messages scan: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list messages"}) return } m.HasPid = m.PID != nil + m.ShortText = h.extractShortText(dataPath, m.Type) messages = append(messages, m) msgIDs = append(msgIDs, m.ID) } @@ -300,7 +306,7 @@ func (h *MessageHandler) Sent(c *gin.Context) { ctx := c.Request.Context() rows, err := h.DB.Pool.Query(ctx, - `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size + `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size, m.filepath FROM msg m WHERE m.from_addr = $1 ORDER BY m.id DESC @@ -318,12 +324,14 @@ func (h *MessageHandler) Sent(c *gin.Context) { var msgIDs []int64 for rows.Next() { var m messageListItem - if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size); err != nil { + var dataPath string + if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size, &dataPath); err != nil { log.Printf("list sent messages scan: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sent messages"}) return } m.HasPid = m.PID != nil + m.ShortText = h.extractShortText(dataPath, m.Type) messages = append(messages, m) msgIDs = append(msgIDs, m.ID) } @@ -397,7 +405,7 @@ func (h *MessageHandler) Sent(c *gin.Context) { c.JSON(http.StatusOK, messages) } -// Create handles POST /api/v1/messages — creates a draft message. +// Create handles POST /fmsg — creates a draft message. func (h *MessageHandler) Create(c *gin.Context) { identity := middleware.GetIdentity(c) @@ -418,6 +426,16 @@ func (h *MessageHandler) Create(c *gin.Context) { return } + if err := validateAddresses(msg.From, msg.To, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := validatePidRelations(msg.PID, msg.Topic, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if int64(len(msg.Data)) > h.MaxDataSize { c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) return @@ -489,7 +507,7 @@ func (h *MessageHandler) Create(c *gin.Context) { c.JSON(http.StatusCreated, gin.H{"id": msgID}) } -// Get handles GET /api/v1/messages/:id — retrieves a message. +// Get handles GET /fmsg/:id — retrieves a message. func (h *MessageHandler) Get(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -498,7 +516,7 @@ func (h *MessageHandler) Get(c *gin.Context) { } ctx := c.Request.Context() - msg, err := h.fetchMessage(ctx, msgID) + msg, dataPath, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -515,10 +533,13 @@ func (h *MessageHandler) Get(c *gin.Context) { return } + // Compute ShortText only after authorization has been confirmed. + msg.ShortText = h.extractShortText(dataPath, msg.Type) + c.JSON(http.StatusOK, msg) } -// DownloadData handles GET /api/v1/messages/:id/data — downloads the message body as a file. +// DownloadData handles GET /fmsg/:id/data — downloads the message body as a file. func (h *MessageHandler) DownloadData(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -565,9 +586,8 @@ func (h *MessageHandler) DownloadData(c *gin.Context) { } // Path traversal protection: ensure the path is within DataDir. - cleanPath := filepath.Clean(dataPath) - cleanDataDir := filepath.Clean(h.DataDir) - if !strings.HasPrefix(cleanPath, cleanDataDir+string(filepath.Separator)) { + cleanPath, ok := safeDataPath(dataPath, h.DataDir) + if !ok { log.Printf("download data: path traversal attempt: %s", dataPath) c.JSON(http.StatusForbidden, gin.H{"error": "access denied"}) return @@ -576,7 +596,7 @@ func (h *MessageHandler) DownloadData(c *gin.Context) { c.FileAttachment(cleanPath, filepath.Base(cleanPath)) } -// Update handles PUT /api/v1/messages/:id — updates a draft message. +// Update handles PUT /fmsg/:id — updates a draft message. func (h *MessageHandler) Update(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -585,7 +605,7 @@ func (h *MessageHandler) Update(c *gin.Context) { } ctx := c.Request.Context() - existing, err := h.fetchMessage(ctx, msgID) + existing, _, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -615,6 +635,16 @@ func (h *MessageHandler) Update(c *gin.Context) { return } + if err := validateAddresses(msg.From, msg.To, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := validatePidRelations(msg.PID, msg.Topic, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if int64(len(msg.Data)) > h.MaxDataSize { c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) return @@ -671,7 +701,7 @@ func (h *MessageHandler) Update(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"id": msgID}) } -// Delete handles DELETE /api/v1/messages/:id — deletes a draft message. +// Delete handles DELETE /fmsg/:id — deletes a draft message. func (h *MessageHandler) Delete(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -680,7 +710,7 @@ func (h *MessageHandler) Delete(c *gin.Context) { } ctx := c.Request.Context() - existing, err := h.fetchMessage(ctx, msgID) + existing, _, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -743,7 +773,7 @@ func (h *MessageHandler) Delete(c *gin.Context) { c.Status(http.StatusNoContent) } -// Send handles POST /api/v1/messages/:id/send — marks a message as sent. +// Send handles POST /fmsg/:id/send — marks a message as sent. func (h *MessageHandler) Send(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -752,7 +782,7 @@ func (h *MessageHandler) Send(c *gin.Context) { } ctx := c.Request.Context() - existing, err := h.fetchMessage(ctx, msgID) + existing, _, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -787,7 +817,7 @@ type addToInput struct { AddTo []string `json:"add_to"` } -// AddRecipients handles POST /api/v1/messages/:id/add-to — adds additional recipients to a message. +// AddRecipients handles POST /fmsg/:id/add-to — adds additional recipients to a message. func (h *MessageHandler) AddRecipients(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -806,13 +836,21 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { return } + for _, addr := range input.AddTo { + if !middleware.IsValidAddr(addr) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid add_to address: %q", addr)}) + return + } + } + ctx := c.Request.Context() // Load the message to verify it exists. var fromAddr string + var pid *int64 err := h.DB.Pool.QueryRow(ctx, - "SELECT from_addr FROM msg WHERE id = $1", msgID, - ).Scan(&fromAddr) + "SELECT from_addr, pid FROM msg WHERE id = $1", msgID, + ).Scan(&fromAddr, &pid) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -823,6 +861,12 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { return } + // add_to is only valid on replies (messages with a pid). + if pid == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "add_to is only valid when pid is supplied"}) + return + } + // Verify the requester is an existing participant (from or msg_to). if fromAddr != identity { var recipientCount int @@ -854,17 +898,20 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { } // fetchMessage loads a message with its recipients and attachments from the DB. -func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models.Message, error) { +// It also returns the raw filepath stored in the database so callers can use it +// after performing their own authorization checks. +func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models.Message, string, error) { row := h.DB.Pool.QueryRow(ctx, - `SELECT version, pid, no_reply, is_important, is_deflate, time_sent, from_addr, topic, type, size FROM msg WHERE id = $1`, + `SELECT version, pid, no_reply, is_important, is_deflate, time_sent, from_addr, topic, type, size, filepath FROM msg WHERE id = $1`, msgID, ) msg := &models.Message{} var pid *int64 var timeSent *float64 - if err := row.Scan(&msg.Version, &pid, &msg.NoReply, &msg.Important, &msg.Deflate, &timeSent, &msg.From, &msg.Topic, &msg.Type, &msg.Size); err != nil { - return nil, err + var dataPath string + if err := row.Scan(&msg.Version, &pid, &msg.NoReply, &msg.Important, &msg.Deflate, &timeSent, &msg.From, &msg.Topic, &msg.Type, &msg.Size, &dataPath); err != nil { + return nil, "", err } msg.PID = pid msg.Time = timeSent @@ -907,7 +954,7 @@ func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models attRows.Close() } - return msg, nil + return msg, dataPath, nil } // saveMessageData writes data to the filesystem and returns the absolute path. @@ -1021,6 +1068,132 @@ func isZip(data []byte) bool { return len(data) >= 4 && data[0] == 0x50 && data[1] == 0x4b && data[2] == 0x03 && data[3] == 0x04 } +// safeDataPath cleans dataPath and verifies it lies inside dataDir. Returns +// the cleaned absolute path and true on success, or "" and false otherwise. +func safeDataPath(dataPath, dataDir string) (string, bool) { + if dataPath == "" || dataDir == "" { + return "", false + } + absPath, err := filepath.Abs(dataPath) + if err != nil { + return "", false + } + absDataDir, err := filepath.Abs(dataDir) + if err != nil { + return "", false + } + cleanPath := filepath.Clean(absPath) + cleanDataDir := filepath.Clean(absDataDir) + relPath, err := filepath.Rel(cleanDataDir, cleanPath) + if err != nil { + return "", false + } + if strings.HasPrefix(relPath, "..") { + return "", false + } + return cleanPath, true +} + +// isTextMIME reports whether the given Content-Type's media type begins with +// "text/". Charset and other parameters are ignored. +func isTextMIME(mimeType string) bool { + if mimeType == "" { + return false + } + mediaType, _, err := mime.ParseMediaType(mimeType) + if err != nil { + return false + } + return strings.HasPrefix(mediaType, "text/") +} + +// extractShortText reads up to ShortTextSize bytes from the message body +// referenced by dataPath and returns it as a string when the message type +// is text/* and the bytes form valid UTF-8. Truncation is rounded down to +// the last complete UTF-8 rune so the result is always valid UTF-8. +// Returns "" on any failure (non-text type, invalid UTF-8, missing/unsafe +// path, read error). Errors are logged but not propagated. +func (h *MessageHandler) extractShortText(dataPath, mimeType string) string { + if h.ShortTextSize <= 0 { + return "" + } + if !isTextMIME(mimeType) { + return "" + } + cleanPath, ok := safeDataPath(dataPath, h.DataDir) + if !ok { + return "" + } + f, err := os.Open(cleanPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + log.Printf("short text: open %s: %v", cleanPath, err) + } + return "" + } + defer f.Close() + + buf := make([]byte, h.ShortTextSize) + n, err := io.ReadFull(f, buf) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + log.Printf("short text: read %s: %v", cleanPath, err) + return "" + } + buf = buf[:n] + + // Trim to the largest valid UTF-8 prefix so that we only drop trailing + // incomplete bytes, while preserving complete multi-byte runes that end + // exactly at the buffer boundary. + for len(buf) > 0 && !utf8.Valid(buf) { + buf = buf[:len(buf)-1] + } + if len(buf) == 0 { + return "" + } + return string(buf) +} + +// validateAddresses returns an error if any of the provided fmsg address +// fields is not a valid "@user@domain" address. addToFrom is optional. +func validateAddresses(from string, to, addTo []string, addToFrom *string) error { + if !middleware.IsValidAddr(from) { + return fmt.Errorf("invalid from address: %q", from) + } + for _, addr := range to { + if !middleware.IsValidAddr(addr) { + return fmt.Errorf("invalid to address: %q", addr) + } + } + for _, addr := range addTo { + if !middleware.IsValidAddr(addr) { + return fmt.Errorf("invalid add_to address: %q", addr) + } + } + if addToFrom != nil && *addToFrom != "" && !middleware.IsValidAddr(*addToFrom) { + return fmt.Errorf("invalid add_to_from address: %q", *addToFrom) + } + return nil +} + +// validatePidRelations enforces: +// - If pid is set, topic must be empty (replies inherit topic from parent). +// - If pid is not set, add_to and add_to_from must be empty (a thread +// must exist before recipients can be added to it). +func validatePidRelations(pid *int64, topic string, addTo []string, addToFrom *string) error { + if pid != nil && topic != "" { + return fmt.Errorf("topic must be empty when pid is supplied") + } + if pid == nil { + if len(addTo) > 0 { + return fmt.Errorf("add_to is only valid when pid is supplied") + } + if addToFrom != nil && *addToFrom != "" { + return fmt.Errorf("add_to_from is only valid when pid is supplied") + } + } + return nil +} + // checkDistinctRecipients returns an error if any address in to or addTo // appears more than once (case-insensitive). func checkDistinctRecipients(to, addTo []string) error { diff --git a/src/handlers/messages_test.go b/src/handlers/messages_test.go index e058ede..27713ac 100644 --- a/src/handlers/messages_test.go +++ b/src/handlers/messages_test.go @@ -1,7 +1,11 @@ package handlers import ( + "os" + "path/filepath" + "strings" "testing" + "unicode/utf8" ) func TestParseAddr(t *testing.T) { @@ -142,3 +146,139 @@ func searchSubstring(s, sub string) bool { } return false } + +func TestIsTextMIME(t *testing.T) { + tests := []struct { + mime string + want bool + }{ + {"text/plain", true}, + {"text/html", true}, + {"text/plain; charset=utf-8", true}, + {"TEXT/PLAIN", true}, + {"application/json", false}, + {"application/octet-stream", false}, + {"image/png", false}, + {"application/pdf", false}, + {"", false}, + {"totally invalid;;;", false}, + } + for _, tt := range tests { + t.Run(tt.mime, func(t *testing.T) { + if got := isTextMIME(tt.mime); got != tt.want { + t.Errorf("isTextMIME(%q) = %v, want %v", tt.mime, got, tt.want) + } + }) + } +} + +func TestSafeDataPath(t *testing.T) { + dir := t.TempDir() + inside := filepath.Join(dir, "sub", "file.txt") + if _, ok := safeDataPath(inside, dir); !ok { + t.Errorf("expected inside path to be allowed: %s", inside) + } + outside := filepath.Join(filepath.Dir(dir), "evil.txt") + if _, ok := safeDataPath(outside, dir); ok { + t.Errorf("expected outside path to be rejected: %s", outside) + } + if _, ok := safeDataPath("", dir); ok { + t.Error("expected empty path to be rejected") + } + if _, ok := safeDataPath(inside, ""); ok { + t.Error("expected empty data dir to be rejected") + } +} + +// writeTempFile writes contents to a fresh file inside dir and returns its path. +func writeTempFile(t *testing.T, dir, name string, contents []byte) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, contents, 0o600); err != nil { + t.Fatalf("write %s: %v", path, err) + } + return path +} + +func TestExtractShortText(t *testing.T) { + dir := t.TempDir() + h := &MessageHandler{DataDir: dir, ShortTextSize: 768} + + t.Run("short text/plain returns full content", func(t *testing.T) { + path := writeTempFile(t, dir, "short.txt", []byte("hello world")) + got := h.extractShortText(path, "text/plain") + if got != "hello world" { + t.Errorf("got %q, want %q", got, "hello world") + } + }) + + t.Run("ascii longer than max truncates exactly", func(t *testing.T) { + body := strings.Repeat("a", 2000) + path := writeTempFile(t, dir, "long.txt", []byte(body)) + got := h.extractShortText(path, "text/plain") + if len(got) != 768 { + t.Errorf("got len %d, want 768", len(got)) + } + if got != strings.Repeat("a", 768) { + t.Errorf("unexpected content") + } + }) + + t.Run("utf8 multibyte truncation respects rune boundaries", func(t *testing.T) { + // "€" is 3 bytes in UTF-8 (0xE2 0x82 0xAC). 768 / 3 = 256 runes + // (= 768 bytes) — exact boundary. Use 257 runes (771 bytes) so the + // truncation must drop a partial rune. + body := strings.Repeat("€", 257) + path := writeTempFile(t, dir, "utf8.txt", []byte(body)) + small := &MessageHandler{DataDir: dir, ShortTextSize: 770} + got := small.extractShortText(path, "text/plain; charset=utf-8") + if !utf8.ValidString(got) { + t.Errorf("result is not valid UTF-8: % x", got) + } + // 770 bytes / 3 bytes per rune => 256 complete runes (768 bytes); + // trailing 2 bytes of the 257th rune must be dropped. + if got != strings.Repeat("€", 256) { + t.Errorf("unexpected truncation: len=%d runes=%d", len(got), utf8.RuneCountInString(got)) + } + }) + + t.Run("non-text mime returns empty", func(t *testing.T) { + path := writeTempFile(t, dir, "img.bin", []byte("hello")) + for _, mt := range []string{"application/octet-stream", "image/png", "application/pdf", "application/json"} { + if got := h.extractShortText(path, mt); got != "" { + t.Errorf("mime %q: got %q, want empty", mt, got) + } + } + }) + + t.Run("invalid utf8 returns empty", func(t *testing.T) { + path := writeTempFile(t, dir, "bad.txt", []byte{0xff, 0xfe, 0xfd, 0xfc}) + if got := h.extractShortText(path, "text/plain"); got != "" { + t.Errorf("got %q, want empty", got) + } + }) + + t.Run("missing file returns empty", func(t *testing.T) { + got := h.extractShortText(filepath.Join(dir, "does-not-exist.txt"), "text/plain") + if got != "" { + t.Errorf("got %q, want empty", got) + } + }) + + t.Run("path traversal returns empty", func(t *testing.T) { + outside := filepath.Join(filepath.Dir(dir), "evil.txt") + _ = os.WriteFile(outside, []byte("evil"), 0o600) + defer os.Remove(outside) + if got := h.extractShortText(outside, "text/plain"); got != "" { + t.Errorf("got %q, want empty", got) + } + }) + + t.Run("zero ShortTextSize disables", func(t *testing.T) { + path := writeTempFile(t, dir, "z.txt", []byte("hello")) + zero := &MessageHandler{DataDir: dir, ShortTextSize: 0} + if got := zero.extractShortText(path, "text/plain"); got != "" { + t.Errorf("got %q, want empty", got) + } + }) +} diff --git a/src/main.go b/src/main.go index 5966ea1..2b88d9c 100644 --- a/src/main.go +++ b/src/main.go @@ -50,6 +50,7 @@ func main() { maxDataSize := int64(envOrDefaultInt("FMSG_API_MAX_DATA_SIZE", 10)) * 1024 * 1024 maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 + shortTextSize := envOrDefaultInt("FMSG_API_SHORT_TEXT_SIZE", 768) // CORS: comma-separated list of allowed browser origins, e.g. // "https://fmsg.io,https://www.fmsg.io". Empty disables CORS. @@ -91,7 +92,7 @@ func main() { router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) // Instantiate handlers. - msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize) + msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize, shortTextSize) attHandler := handlers.NewAttachmentHandler(database, dataDir, maxAttachSize, maxMsgSize) // Register routes under /fmsg, all protected by JWT. diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index 2a6e466..e8f530e 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -151,7 +151,7 @@ func New(cfg Config) (gin.HandlerFunc, error) { } addr, _ := claims["sub"].(string) - if !isValidAddr(addr) { + if !IsValidAddr(addr) { log.Printf("auth rejected: ip=%s reason=invalid_addr sub=%q", c.ClientIP(), addr) respondAuth(c, http.StatusUnauthorized, "invalid identity") return @@ -234,8 +234,8 @@ func GetIdentity(c *gin.Context) string { return addr } -// isValidAddr checks that the address has the form "@user@domain". -func isValidAddr(addr string) bool { +// IsValidAddr checks that the address has the form "@user@domain". +func IsValidAddr(addr string) bool { if len(addr) < 3 { return false } diff --git a/src/middleware/jwt_test.go b/src/middleware/jwt_test.go index 549b1a9..1cbd94b 100644 --- a/src/middleware/jwt_test.go +++ b/src/middleware/jwt_test.go @@ -33,7 +33,7 @@ func TestIsValidAddr(t *testing.T) { } for _, tt := range tests { t.Run(tt.addr, func(t *testing.T) { - if got := isValidAddr(tt.addr); got != tt.want { + if got := IsValidAddr(tt.addr); got != tt.want { t.Errorf("isValidAddr(%q) = %v, want %v", tt.addr, got, tt.want) } }) diff --git a/src/models/models.go b/src/models/models.go index c19324e..cb8d8d5 100644 --- a/src/models/models.go +++ b/src/models/models.go @@ -23,5 +23,6 @@ type Message struct { Topic string `json:"topic"` Type string `json:"type"` Size int `json:"size"` + ShortText string `json:"short_text,omitempty"` Attachments []Attachment `json:"attachments"` }