diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c9486af..e1c3d71 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,10 +3,9 @@ name: Build on: push: tags: ['v*'] - pull_request: - branches: [main] - + jobs: + build: strategy: matrix: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..5f1a7ec --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,39 @@ +name: Test & Build + +on: + pull_request: + branches: [main, dev] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.26' + cache: true + + - name: Run tests + run: go test ./... -v + build: + strategy: + matrix: + arch: [amd64, arm64] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.26' + cache: true + + - name: Build yatund + run: | + GOOS=linux GOARCH=${{ matrix.arch }} CGO_ENABLED=0 go build -o build/yatund-${{ matrix.arch }} ./cmd/yatund/ + + - name: Build yatun + run: | + GOOS=linux GOARCH=${{ matrix.arch }} CGO_ENABLED=0 go build -o build/yatun-${{ matrix.arch }} ./cmd/yatun/ + diff --git a/.gitignore b/.gitignore index 4c49bd7..49c8451 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ .env +/yatun +/yatund +*.exe + diff --git a/README.md b/README.md index 8fabcaf..8cbade6 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ ``` ┌──────────┐ TCP ┌────────────────┐ yamux ┌──────────┐ TCP ┌─────────┐ -│ Internet │───────────▶│ yatund │◀══════════▶│ yatun │────────▶│ local │ -│ Client │ random port│ (relay server) │ session │ (agent) │ :port │ service │ +│ Internet │───────────▶│ yatund │◀═════════▶│ yatun │────────▶│ local │ +│ Client │ random port│ (relay server)│ session │ (agent) │ :port │ service │ └──────────┘ └────────────────┘ └──────────┘ └─────────┘ ``` diff --git a/cmd/yatun/main.go b/cmd/yatun/main.go index fe5a05e..a2697f4 100644 --- a/cmd/yatun/main.go +++ b/cmd/yatun/main.go @@ -1,12 +1,13 @@ package main import ( - "errors" + "context" "flag" "fmt" "io" - "os" + "sync/atomic" + "sync" "time" @@ -18,34 +19,51 @@ import ( "github.com/hashicorp/yamux" ) -func sendMsg(con *yamux.Stream, m message.TransportMessage) { - byt := m.Encode() +func sendMsg(con *yamux.Stream, m message.TransportMessage) (err error) { + byt, err := m.Encode() + if err != nil { + + return fmt.Errorf("failed to encode message: %w", err) + } + + _, err = con.Write(byt) + return - con.Write(byt) } -func clientServerComms(ses *yamux.Session, tuiP *tea.Program) { +func clientServerComms(ses *yamux.Session, tuiP *tea.Program) (err error) { con, err := ses.OpenStream() if err != nil { - tuiP.Kill() - panic(errors.New("failed to accept new stream, is the server running?")) - } - - dat := message.ConnectionDetailsMessageData{ - SubdomainName: "asd", + tuiP.Send(tui.SetState{ + State: tui.ErrorState, + Err: err, + }) + tuiP.Quit() + return } - sendMsg(con, message.TransportMessage{ - Type: message.ConnectionDetails, - Data: &dat, + err = sendMsg(con, message.TransportMessage{ + Type: message.OpenMsg, }) + if err != nil { + tuiP.Send(tui.SetState{ + State: tui.ErrorState, + Err: fmt.Errorf("failed to send initial message to server, is the server running?\n%v", err), + }) + tuiP.Quit() + return err + } go func() { for { msg, err := message.Decode(con) if err != nil { - tuiP.Kill() - panic(errors.New("the server closed unexpectedly")) + tuiP.Send(tui.SetState{ + State: tui.ErrorState, + Err: fmt.Errorf("failed at decoding server message\n%v", err), + }) + tuiP.Quit() + return } switch msg.Type { @@ -60,6 +78,7 @@ func clientServerComms(ses *yamux.Session, tuiP *tea.Program) { } } }() + return nil } func initializeServerConnection(tuiP *tea.Program, server string) (sess *yamux.Session, err error) { @@ -74,7 +93,7 @@ func initializeServerConnection(tuiP *tea.Program, server string) (sess *yamux.S return } - clientServerComms(sess, tuiP) + err = clientServerComms(sess, tuiP) return } @@ -83,82 +102,121 @@ type trafficMonitor struct { underlying io.ReadWriter tuiP *tea.Program streamType tui.TrafficDirection + + bytesTransferred *atomic.Int64 } func (c trafficMonitor) Read(p []byte) (n int, err error) { n, err = c.underlying.Read(p) - go c.tuiP.Send(tui.TrafficUpdate{ - Direction: c.streamType, - Bytes: n, - }) + c.bytesTransferred.Add(int64(n)) return } func (c trafficMonitor) Write(p []byte) (n int, err error) { return c.underlying.Write(p) + } -func serverConnectionLoop(sess *yamux.Session, port *string, tuiP *tea.Program) { - // TODO: After initial handshake is done, io.Copy from server (yatun) to internal target server - for { +func handleStream(ctx context.Context, stream *yamux.Stream, port *string, tuiP *tea.Program) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() - stream, err := sess.AcceptStream() - if err != nil { - // TODO: Maybe? send a message to the TUI so that the user knows it is having trouble getting new sessions from server - tuiP.Send(tui.SetState{ - Err: err, - State: tui.ErrorState, - }) - return - } + tuiP.Send(tui.LiveConnection) + defer tuiP.Send(tui.DeadConnection) + defer stream.Close() - go func() { - tuiP.Send(tui.LiveConnection) - defer tuiP.Send(tui.DeadConnection) - defer stream.Close() + localConn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%v", *port), time.Second*10) + if err != nil { + tuiP.Send(tui.LocalConnectionError) + // Notify to the TUI the error, maybe the server is down? + return + } + defer localConn.Close() - localConn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%v", *port), time.Second*10) - if err != nil { - tuiP.Send(tui.LocalConnectionError) - // Notify to the TUI the error, maybe the server is down? + go func() { + <-ctx.Done() + + localConn.Close() + stream.Close() + }() + + streamMonitor := trafficMonitor{ + underlying: stream, + tuiP: tuiP, + streamType: tui.Inbound, + bytesTransferred: &atomic.Int64{}, + } + localConnMonitor := trafficMonitor{ + underlying: localConn, + tuiP: tuiP, + streamType: tui.Outbound, + bytesTransferred: &atomic.Int64{}, + } + + go func() { + t := time.NewTicker(time.Second * 5) + defer t.Stop() + + for { + select { + case <-t.C: + // The TUI already sums the data internally, so the right call is Swap instead of load, this could also be used to measure throughput + tuiP.Send(tui.TrafficUpdate{ + Direction: streamMonitor.streamType, + Bytes: int(streamMonitor.bytesTransferred.Swap(0)), + }) + tuiP.Send(tui.TrafficUpdate{ + Direction: localConnMonitor.streamType, + Bytes: int(localConnMonitor.bytesTransferred.Swap(0)), + }) + + case <-ctx.Done(): return } - defer localConn.Close() + } + }() - streamCopier := trafficMonitor{ - underlying: stream, - tuiP: tuiP, - streamType: tui.Inbound, - } - localConnCopier := trafficMonitor{ - underlying: localConn, - tuiP: tuiP, - streamType: tui.Outbound, - } + wg := sync.WaitGroup{} - wg := sync.WaitGroup{} + wg.Go(func() { - wg.Go(func() { + io.Copy(streamMonitor, localConnMonitor) + cancel() - io.Copy(streamCopier, localConnCopier) - localConn.Close() + }) - }) + wg.Go(func() { + io.Copy(localConnMonitor, streamMonitor) + cancel() - wg.Go(func() { - io.Copy(localConnCopier, streamCopier) - stream.Close() + }) - }) + wg.Wait() +} - wg.Wait() +func serverConnectionLoop(ctx context.Context, sess *yamux.Session, port *string, tuiP *tea.Program) { + // TODO: After initial handshake is done, io.Copy from server (yatun) to internal target server + for { - }() + stream, err := sess.AcceptStreamWithContext(ctx) + if err != nil { + // TODO: Maybe? send a message to the TUI so that the user knows it is having trouble getting new sessions from server + tuiP.Send(tui.SetState{ + Err: err, + State: tui.ErrorState, + }) + tuiP.Quit() + return + } + go handleStream(ctx, stream, port, tuiP) } } func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + port := flag.String("port", "", "--port") server := flag.String("server", "yatun.snowdev.one", "--server") @@ -175,20 +233,26 @@ func main() { sess, err := initializeServerConnection(tuiP, *server) if err != nil { go tuiP.Send(tui.SetState{ - Err: err, + + Err: &tui.FailedInitialConfigError{ + Err: err}, State: tui.ErrorState, }) + // tuiP.Quit() } if err == nil { go tuiP.Send(tui.SetState{ State: tui.OnlineState, }) - go serverConnectionLoop(sess, port, tuiP) + go serverConnectionLoop(ctx, sess, port, tuiP) } if _, err := tuiP.Run(); err != nil { - sess.Close() + if sess != nil { + sess.Close() + } + cancel() os.Exit(1) } } diff --git a/cmd/yatund/main.go b/cmd/yatund/main.go index d8ce9fb..391dbde 100644 --- a/cmd/yatund/main.go +++ b/cmd/yatund/main.go @@ -1,15 +1,27 @@ package main import ( + "context" "log" "net" + "os/signal" + "syscall" + "github.com/KatIsCoding/yatun/internal/config" "github.com/KatIsCoding/yatun/internal/server" ) func main() { - conf := config.ReadFromEnv() + ctx := context.Background() + + ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + conf, err := config.ReadFromEnv() + if err != nil { + log.Fatalf("FATAL: Failed to parse configuration\n%v", err) + } // So, we need to open a tcp server for external connections and another one for the agent. // But the external should only open if an agent requests it, so the very first server is the agent one. @@ -17,23 +29,40 @@ func main() { // The agents listener listener, err := net.Listen("tcp", "0.0.0.0:5678") if err != nil { - panic(err) + log.Printf("Error binding to port 5678: %v", err) + return } + go func() { + <-ctx.Done() + listener.Close() + }() + log.Printf("Listening on %v", listener.Addr().String()) for { conn, err := listener.Accept() if err != nil { - panic(err) + if ctx.Err() != nil { + log.Printf("Loop break, context canceled") + return + } + log.Printf("Failed to accept client: %v", err) + continue } sconn, err := server.NewServerConnection(conn, conf) if err != nil { - panic(err) + log.Printf("Failed on the creation of a new server: %v", err) + continue } - go sconn.StartListeningAgents() + go func() { + err := sconn.StartListeningAgents(ctx) + if err != nil { + log.Printf("Failed starting the agent(s) setup and loop: %v", err) + } + }() } } diff --git a/go.mod b/go.mod index 6294ec6..9ede070 100644 --- a/go.mod +++ b/go.mod @@ -3,33 +3,26 @@ module github.com/KatIsCoding/yatun go 1.26.1 require ( - charm.land/bubbles/v2 v2.1.0 // indirect - charm.land/bubbletea/v2 v2.0.6 // indirect - charm.land/lipgloss/v2 v2.0.3 // indirect - github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/bubbletea v1.3.10 // indirect + charm.land/bubbles/v2 v2.1.0 + charm.land/bubbletea/v2 v2.0.6 + charm.land/lipgloss/v2 v2.0.3 + github.com/hashicorp/yamux v0.1.2 +) + +require ( github.com/charmbracelet/colorprofile v0.4.3 // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468 // indirect github.com/charmbracelet/x/ansi v0.11.7 // indirect - github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/term v0.2.2 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect - github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/hashicorp/yamux v0.1.2 // indirect github.com/lucasb-eyer/go-colorful v1.4.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.23 // indirect - github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.16.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.3.8 // indirect ) diff --git a/go.sum b/go.sum index 264bb70..a7ad83b 100644 --- a/go.sum +++ b/go.sum @@ -4,26 +4,16 @@ charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo= charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g= charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU= charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA= -github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= -github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= -github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= -github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o= +github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w= github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q= github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q= -github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= -github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468 h1:Q9fO0y1Zo5KB/5Vu8JZoLGm1N3RzF9bNj3Ao3xoR+Ac= github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468/go.mod h1:bAAz7dh/FTYfC+oiHavL4mX1tOIBZ0ZwYjSi3qE6ivM= -github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= -github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI= github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= -github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= -github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= -github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA= +github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= @@ -34,40 +24,21 @@ github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSE github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= -github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4= github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= -github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= -github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= -github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= -github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= diff --git a/internal/config/config.go b/internal/config/config.go index 4e1ec5c..496ccd0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,16 +17,17 @@ type ServerConfig struct { TLSStore tlsutil.TLSStore } -func ReadFromEnv() ServerConfig { +func ReadFromEnv() (ServerConfig, error) { c := ServerConfig{} - domain, ok := os.LookupEnv("Domain") + domain, ok := os.LookupEnv("DOMAIN") if ok { if !strings.Contains(domain, "//") { domain = "//" + domain } u, err := url.Parse(domain) if err != nil { - panic(err) + log.Printf("Failed to parse set domain: %v", err) + return c, err } log.Printf("Host: %v", u.Hostname()) c.Domain = new(u.Hostname()) @@ -35,8 +36,12 @@ func ReadFromEnv() ServerConfig { tls, ok := os.LookupEnv("TLS") if ok && tls != "0" { c.TLS = true - c.TLSStore = tlsutil.LoadTLSCerts() + store, err := tlsutil.LoadTLSCerts() + if err != nil { + return c, err + } + c.TLSStore = store } - return c + return c, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..fadff5e --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,81 @@ +package config + +import ( + "os" + "testing" +) + +func TestReadFromEnvDefaults(t *testing.T) { + os.Unsetenv("DOMAIN") + os.Unsetenv("TLS") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.Domain != nil { + t.Errorf("expected nil Domain, got %q", *cfg.Domain) + } + if cfg.TLS { + t.Errorf("expected TLS to be false by default") + } +} + +func TestReadFromEnvDomain(t *testing.T) { + os.Unsetenv("TLS") + t.Setenv("DOMAIN", "example.com") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.Domain == nil { + t.Fatal("expected Domain to be set") + } + if *cfg.Domain != "example.com" { + t.Errorf("expected example.com, got %q", *cfg.Domain) + } +} + +func TestReadFromEnvDomainWithScheme(t *testing.T) { + os.Unsetenv("TLS") + t.Setenv("DOMAIN", "https://sub.example.com") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.Domain == nil { + t.Fatal("expected Domain to be set") + } + if *cfg.Domain != "sub.example.com" { + t.Errorf("expected sub.example.com, got %q", *cfg.Domain) + } +} + +func TestReadFromEnvDomainInvalidURL(t *testing.T) { + os.Unsetenv("TLS") + t.Setenv("DOMAIN", "://bad host") + + _, err := ReadFromEnv() + if err == nil { + t.Fatal("expected error for invalid domain URL") + } +} + +func TestReadFromEnvTLSDisabled(t *testing.T) { + os.Unsetenv("DOMAIN") + t.Setenv("TLS", "0") + + cfg, err := ReadFromEnv() + if err != nil { + t.Fatalf("ReadFromEnv: %v", err) + } + + if cfg.TLS { + t.Errorf("expected TLS to be false when set to 0") + } +} diff --git a/internal/message/message.go b/internal/message/message.go index e869e72..0f976bd 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "encoding/json" "errors" + "fmt" + "log" "time" "io" @@ -14,25 +16,6 @@ type MessagePayload interface { Decode(b []byte) error } -type ConnectionDetailsMessageData struct { - SubdomainName string `json:"address"` -} - -func (c ConnectionDetailsMessageData) Encode() ([]byte, error) { - return json.Marshal(c) -} -func (c *ConnectionDetailsMessageData) Decode(b []byte) error { - var t ConnectionDetailsMessageData - err := json.Unmarshal(b, &t) - if err != nil { - return err - } - - // log.Printf("Subdom %v", t.SubdomainName) - c.SubdomainName = t.SubdomainName - return nil -} - type Response struct { Ok bool `json:"ok"` Address string `json:"address"` @@ -46,7 +29,7 @@ func (r *Response) Decode(b []byte) error { var rT Response err := json.Unmarshal(b, &rT) if err != nil { - return err + return fmt.Errorf("failed to unmarshall message json: %w", err) } r.Address = rT.Address @@ -66,7 +49,7 @@ func (r *PingMessage) Decode(b []byte) error { var rT PingMessage err := json.Unmarshal(b, &rT) if err != nil { - return err + return fmt.Errorf("failed to unmarshall message: %w", err) } r.Time = rT.Time @@ -80,16 +63,13 @@ type TransportMessage struct { Data MessagePayload } -var TCPConnection MessageType = 'A' // Agent -> Server msg type, used for signaling the agent is asking for TCP, no data transfer. -var ConnectionDetails MessageType = 'B' // Server -> Agent msg type, used to transfer data about the server created, data is transferred (json, maybe change later) +var TCPConnection MessageType = 'A' // Agent -> Server msg type, used for signaling the agent is asking for TCP, no data transfer. +var OpenMsg MessageType = 'B' // Server -> Agent msg type, used to transfer data about the server created, data is transferred (json, maybe change later) var ResponseMessage MessageType = 'C' var PingMessageType MessageType = 'D' -func (m TransportMessage) GetConnectionDetails() (*MessagePayload, error) { - return &m.Data, nil -} - func ParseConnectionType(b byte) (*MessageType, error) { + log.Printf("Byte rec %v", b) switch b { case byte(TCPConnection): return &TCPConnection, nil @@ -99,7 +79,7 @@ func ParseConnectionType(b byte) (*MessageType, error) { } } -func (m TransportMessage) Encode() []byte { +func (m TransportMessage) Encode() ([]byte, error) { // Shape // [type: 1byte][dataLen: 2 byte][data: ?bytes] out := make([]byte, 1) @@ -110,7 +90,7 @@ func (m TransportMessage) Encode() []byte { dataBuf, err := m.Data.Encode() if err != nil { - panic(err) + return nil, err } l := uint16(len(dataBuf)) @@ -122,7 +102,7 @@ func (m TransportMessage) Encode() []byte { out = append(out, dataBuf...) // Data } - return out + return out, nil } func readSize(r io.Reader) (uint16, error) { @@ -141,7 +121,7 @@ func readSize(r io.Reader) (uint16, error) { func Decode(r io.Reader) (*TransportMessage, error) { b := make([]byte, 1) - n, err := r.Read(b) + n, err := io.ReadFull(r, b) if err != nil { return nil, err @@ -159,6 +139,10 @@ func Decode(r io.Reader) (*TransportMessage, error) { return &m, nil } + if mType == OpenMsg { + return &m, nil + } + // Types that have some data, read and put it into the message for later Decoding. // Read the next data stuff bufSize, err := readSize(r) @@ -173,20 +157,20 @@ func Decode(r io.Reader) (*TransportMessage, error) { } var obj MessagePayload switch mType { - case ConnectionDetails: - obj = &ConnectionDetailsMessageData{} + case ResponseMessage: obj = &Response{} case PingMessageType: obj = &PingMessage{} } - err = obj.Decode(buf[:n]) - if err != nil { - return nil, err + if obj != nil { + err = obj.Decode(buf[:n]) + if err != nil { + return nil, err + } + m.Data = obj } - m.Data = obj - return &m, nil } diff --git a/internal/message/message_test.go b/internal/message/message_test.go new file mode 100644 index 0000000..d2650a9 --- /dev/null +++ b/internal/message/message_test.go @@ -0,0 +1,247 @@ +package message + +import ( + "bytes" + "io" + "strings" + "testing" + "time" +) + +func TestEncodeDecodeResponse(t *testing.T) { + original := TransportMessage{ + Type: ResponseMessage, + Data: &Response{ + Ok: true, + Address: "example.com:443", + }, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Type != original.Type { + t.Errorf("type: got %c, want %c", decoded.Type, original.Type) + } + + resp, ok := decoded.Data.(*Response) + if !ok { + t.Fatalf("expected *Response, got %T", decoded.Data) + } + if resp.Ok != true { + t.Errorf("Ok: got %v, want true", resp.Ok) + } + if resp.Address != "example.com:443" { + t.Errorf("Address: got %q, want %q", resp.Address, "example.com:443") + } +} + +func TestEncodeDecodePingMessage(t *testing.T) { + now := time.Now().Truncate(time.Millisecond) + + original := TransportMessage{ + Type: PingMessageType, + Data: &PingMessage{Time: now}, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Type != PingMessageType { + t.Errorf("type: got %c, want %c", decoded.Type, PingMessageType) + } + + ping, ok := decoded.Data.(*PingMessage) + if !ok { + t.Fatalf("expected *PingMessage, got %T", decoded.Data) + } + if !ping.Time.Equal(now) { + t.Errorf("Time: got %v, want %v", ping.Time, now) + } +} + +func TestEncodeDecodeNoData(t *testing.T) { + for _, mt := range []MessageType{TCPConnection, OpenMsg} { + original := TransportMessage{ + Type: mt, + Data: nil, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("encode %c: %v", mt, err) + } + + if len(encoded) != 1 { + t.Errorf("expected 1 byte for data-less message, got %d", len(encoded)) + } + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode %c: %v", mt, err) + } + + if decoded.Type != mt { + t.Errorf("type: got %c, want %c", decoded.Type, mt) + } + if decoded.Data != nil { + t.Errorf("expected nil Data for type %c, got %v", mt, decoded.Data) + } + } +} + +func TestDecodeUnknownType(t *testing.T) { + // Create a message with an unknown type 'X' and a valid-length JSON payload + encoded := []byte{'X', 0, 2, '{', '}'} + + decoded, err := Decode(bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Type != MessageType('X') { + t.Errorf("type: got %c, want X", decoded.Type) + } + if decoded.Data != nil { + t.Errorf("expected nil Data for unknown type, got %v", decoded.Data) + } +} + +func TestDecodeTruncatedType(t *testing.T) { + _, err := Decode(strings.NewReader("")) + if err != io.EOF { + t.Errorf("expected io.EOF for empty input, got %v", err) + } +} + +func TestDecodeTruncatedSize(t *testing.T) { + // Type byte present, but no size bytes + _, err := Decode(bytes.NewReader([]byte{'C'})) + if err == nil { + t.Fatal("expected error for truncated size field") + } +} + +func TestDecodeTruncatedPayload(t *testing.T) { + // Type 'C', size says 100 bytes, but only 5 follow + encoded := bytes.NewBuffer([]byte{'C', 0, 100}) + encoded.Write([]byte("short")) + + _, err := Decode(bytes.NewReader(encoded.Bytes())) + if err == nil { + t.Fatal("expected error for truncated payload") + } +} + +func TestDecodeGarbagePayload(t *testing.T) { + // Type 'C', size says 5 bytes of garbage (not valid JSON) + encoded := bytes.NewBuffer([]byte{'C', 0, 5}) + encoded.Write([]byte("!!!!!")) + + _, err := Decode(bytes.NewReader(encoded.Bytes())) + if err == nil { + t.Fatal("expected error for garbage JSON payload") + } +} + +func TestEncodeWithNilData(t *testing.T) { + m := TransportMessage{ + Type: PingMessageType, + Data: nil, + } + + encoded, err := m.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + if len(encoded) != 1 { + t.Errorf("expected 1 byte with nil Data, got %d", len(encoded)) + } + + if encoded[0] != byte(PingMessageType) { + t.Errorf("type byte: got %c, want %c", encoded[0], PingMessageType) + } +} + +func TestParseConnectionType(t *testing.T) { + mt, err := ParseConnectionType('A') + if err != nil { + t.Fatalf("ParseConnectionType('A'): %v", err) + } + if *mt != TCPConnection { + t.Errorf("got %c (%d), want %c (%d)", *mt, *mt, TCPConnection, TCPConnection) + } + + _, err = ParseConnectionType('Z') + if err == nil { + t.Fatal("expected error for unknown type") + } +} + +func TestResponseEncodeDecode(t *testing.T) { + original := Response{Ok: false, Address: "127.0.0.1:3000"} + + data, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + var decoded Response + if err := decoded.Decode(data); err != nil { + t.Fatalf("decode: %v", err) + } + + if decoded.Ok != original.Ok || decoded.Address != original.Address { + t.Errorf("mismatch: got %+v, want %+v", decoded, original) + } +} + +func TestPingMessageEncodeDecode(t *testing.T) { + now := time.Now().Truncate(time.Millisecond) + original := PingMessage{Time: now} + + data, err := original.Encode() + if err != nil { + t.Fatalf("encode: %v", err) + } + + var decoded PingMessage + if err := decoded.Decode(data); err != nil { + t.Fatalf("decode: %v", err) + } + + if !decoded.Time.Equal(original.Time) { + t.Errorf("Time mismatch: got %v, want %v", decoded.Time, original.Time) + } +} + +func TestResponseDecodeInvalidJSON(t *testing.T) { + var r Response + err := r.Decode([]byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestPingMessageDecodeInvalidJSON(t *testing.T) { + var p PingMessage + err := p.Decode([]byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 33af7b6..b51b8b5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "context" "errors" "fmt" @@ -18,87 +19,80 @@ import ( ) type ServerConnection struct { - agentSession *yamux.Session - connectionType *message.MessageType - connectionDetails *message.ConnectionDetailsMessageData - config config.ServerConfig + agentSession *yamux.Session + connectionType *message.MessageType + + config config.ServerConfig + + serverCancelFunc context.CancelFunc } func (s *ServerConnection) handleInitialConfig(stream *yamux.Stream) error { - // Since this is supposed to be called once only, we don't need a for - // TODO: Ctx with timeout maybe? - // - - // buf := make([]byte, 1) - // _, err := stream.Read(buf) - // if err != nil { - // return err - // } - // + // TODO: Maybe this initial message could be used for something? I'm not quite sure what just yet though m, err := message.Decode(stream) if err != nil { - return err + return fmt.Errorf("failed to decode initial config message: %w", err) } - if m.Type == message.ConnectionDetails { - details, err := m.GetConnectionDetails() - if err != nil { - return err - } - - if details == nil { - return errors.New("no details") - } - - // dr := *details - pDet := (*details).(*message.ConnectionDetailsMessageData) - log.Printf("Config with %v addr", pDet.SubdomainName) - - s.connectionDetails = pDet - } + log.Printf("Received message %v", m) s.connectionType = &m.Type return nil } -func (s *ServerConnection) setupServer() (net.Listener, error) { - // TODO: Expand this for more server options - return net.Listen("tcp", "0.0.0.0:") -} +func (s *ServerConnection) setupServer(ctx context.Context) (net.Listener, error) { -func pipe(out io.Writer, in io.Reader) { - _, err := io.Copy(out, in) + server, err := net.Listen("tcp", "0.0.0.0:") if err != nil { - log.Printf("Error while copying %v", err) + return server, fmt.Errorf("error starting tcp server: %w", err) } + ctx, s.serverCancelFunc = context.WithCancel(ctx) + + go func() { + <-ctx.Done() + + err := server.Close() + if err != nil { + log.Printf("Failed to close the server: %v", err) + } + }() + + return server, nil } -func (s *ServerConnection) handleNewConn(conn net.Conn) { +func (s *ServerConnection) handleNewConn(ctx context.Context, conn net.Conn) { agConn, err := s.agentSession.OpenStream() if err != nil { log.Printf("Error opening yamux stream %v", err) return } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + agConn.Close() + conn.Close() + }() + wg := sync.WaitGroup{} wg.Go(func() { + defer cancel() _, err := io.Copy(agConn, conn) log.Printf("Conn finished %v", err) - agConn.Close() }) wg.Go(func() { + defer cancel() _, err := io.Copy(conn, agConn) log.Printf("AgConn finished %v", err) - conn.Close() }) wg.Wait() log.Printf("WG finished") - conn.Close() - agConn.Close() } type bufferedConn struct { @@ -110,7 +104,7 @@ func (b *bufferedConn) Read(p []byte) (int, error) { return b.buf.Read(p) } -func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (net.Conn, error) { +func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (*bufferedConn, error) { buf := bufio.NewReader(conn) ob := bufferedConn{ @@ -119,67 +113,67 @@ func (s *ServerConnection) peekAndUpgrade(conn net.Conn) (net.Conn, error) { } if !s.config.TLS { - return conn, nil + return &ob, nil } fB, err := buf.Peek(1) if err != nil { log.Printf("Error peeking into the connection %v", err) - return conn, err + return &ob, err } - log.Printf("TCP first bytes %v", fB) - if fB[0] == 0x16 { log.Printf("Initiating TLS") - if !s.config.TLS { - - log.Printf("The incoming request attempted to initiate a TLS connection, but TLS is disabled.") - return conn, nil - } upgraded, err := s.config.TLSStore.Wrap(&ob) if err != nil { - return conn, err + return &ob, err } - // buf := bufio.NewReader(upgraded) - // ob.buf = buf - // ob.Conn = upgraded + buf := bufio.NewReader(upgraded) + ob.buf = buf + ob.Conn = upgraded log.Printf("TLS Upgrade success") - return upgraded, nil + return &ob, nil } - return conn, nil + return &ob, nil } -func (s *ServerConnection) handleExternalConnections(serv net.Listener) error { +func (s *ServerConnection) handleExternalConnections(ctx context.Context, serv net.Listener) error { // TODO: Accept a context so that if the connection with the agent is lost, we close the server! for { conn, err := serv.Accept() if err != nil { - log.Printf("Error accepting external conn %v", err) - return err - } + if ctx.Err() != nil { + return nil + } - // if tcpConn, ok := conn.(*net.TCPConn); ok { - // log.Printf("TCPConn conv") + return fmt.Errorf("error accepting external connection: %w", err) + } - // tcpConn.SetKeepAlive(true) - // tcpConn.SetKeepAlivePeriod(time.Second * 5) - // } - // + go func() { + <-ctx.Done() + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + } + }() // Check if the connection requires TLS and upgrade in case it does - buffered, err := s.peekAndUpgrade(conn) + buffered, err := s.peekAndUpgrade(conn) // This function will only return an error IF TLS handshake fails, but if the request doesn't even ask for TLS it will return with the original conn and no error + if err != nil { + log.Printf("Error upgrading connection: %v", err) + continue + } log.Printf("New req received") - go s.handleNewConn(buffered) + go s.handleNewConn(ctx, buffered) } } @@ -188,7 +182,9 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List if s.config.Domain != nil { _, port, err := net.SplitHostPort(server.Addr().String()) if err != nil { - panic(err) + + return fmt.Errorf("error parsing host/port from address: %w", err) + } addr = fmt.Sprintf("%v:%v", *s.config.Domain, port) @@ -202,16 +198,20 @@ func (s *ServerConnection) sendAddressInfo(stream *yamux.Stream, server net.List }, } - _, err := stream.Write(m.Encode()) + byt, err := m.Encode() + if err != nil { + return fmt.Errorf("error encoding address info message: %w", err) + } + + _, err = stream.Write(byt) if err != nil { - log.Printf("Error sending to the client response info") - return err + return fmt.Errorf("error writing to stream: %w", err) } return nil } -func pingLoop(stream *yamux.Stream) { +func pingLoop(stream *yamux.Stream) error { t := time.NewTicker(time.Second * 5) // TODO: Change to something like 30s defer t.Stop() @@ -223,56 +223,75 @@ func pingLoop(stream *yamux.Stream) { }, } - _, err := stream.Write(m.Encode()) + byt, err := m.Encode() if err != nil { - return + log.Printf("Failed to encode ping message for agent: %v", err) + return err + } + + _, err = stream.Write(byt) + if err != nil { + log.Printf("Failed to send ping notification to agent: %v", err) + return err } } + return nil } -func (s *ServerConnection) StartListeningAgents() error { +func (s *ServerConnection) StartListeningAgents(ctx context.Context) error { - stream, err := s.agentSession.AcceptStream() + stream, err := s.agentSession.AcceptStreamWithContext(ctx) if err != nil { if errors.Is(err, io.EOF) { // This means the agent stream disconnected, it is fine to just break here return errors.New("sess disconnected") } - log.Printf("Unrecognized err: %v", err) - return err + + return fmt.Errorf("unrecoginzed err when listening to agents: %w", err) } // When we receive an agent stream, listen and parse for the first handshake and then we can start copying over / creating new sessions - s.handleInitialConfig(stream) + err = s.handleInitialConfig(stream) + if err != nil { + return fmt.Errorf("setting up the initial config for agent failed: %w", err) + } // After config, we setup whatever server type needs to be opened, and start copying over // defer stream.Close() - server, err := s.setupServer() + server, err := s.setupServer(ctx) if err != nil { - panic(err) + return fmt.Errorf("failed to setup tcp server: %w", err) } err = s.sendAddressInfo(stream, server) if err != nil { - log.Printf("Error sending address information %v", err) - return err + return fmt.Errorf("failed to send initial information to the agent: %w", err) + } - go pingLoop(stream) + go func() { + err := pingLoop(stream) + if err != nil { + log.Printf("Error executing the ping loop\n%v", err) + s.agentSession.Close() + return + } + }() go func() { + // If the agent disconnects, also close the server <-s.agentSession.CloseChan() - server.Close() - log.Printf("Server %v closed", server.Addr()) + s.serverCancelFunc() + }() log.Printf("Port opened in %v", server.Addr().String()) go func() { - err := s.handleExternalConnections(server) + err := s.handleExternalConnections(ctx, server) if err != nil { log.Printf("Error handling external connection: %v", err) } diff --git a/internal/tls/tlsutil.go b/internal/tls/tlsutil.go index cd6895e..d97fa70 100644 --- a/internal/tls/tlsutil.go +++ b/internal/tls/tlsutil.go @@ -20,10 +20,11 @@ func (t TLSStore) Wrap(conn net.Conn) (*tls.Conn, error) { return serv, nil } -func LoadTLSCerts() TLSStore { +func LoadTLSCerts() (TLSStore, error) { cert, err := tls.LoadX509KeyPair("certs/cert.cer", "certs/cert_key.key") if err != nil { - panic(err) + log.Printf("Failed to load certificates: %v", err) + return TLSStore{}, err } conf := &tls.Config{ @@ -32,5 +33,5 @@ func LoadTLSCerts() TLSStore { return TLSStore{ conf: conf, - } + }, nil } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index c037ece..950f919 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,8 +1,10 @@ package tui import ( + "errors" "fmt" "image/color" + "strings" "time" @@ -12,6 +14,14 @@ import ( "github.com/KatIsCoding/yatun/internal/message" ) +type FailedInitialConfigError struct { + Err error +} + +func (f *FailedInitialConfigError) Error() string { + return f.Err.Error() +} + type ServerAddress struct { Addr string } @@ -96,6 +106,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case SetState: m.state = msg.State m.err = msg.Err + + // var x *FailedInitialConfigError + if _, ok := errors.AsType[*FailedInitialConfigError](msg.Err); ok { + return m, tea.Quit + } case tea.KeyPressMsg: switch msg.String() { case "a":