From a1dcaea6ee8cd74c010da68c0e5984addc066e97 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:07 -0500 Subject: [PATCH 01/51] connmgr: Add context-aware semaphore. This adds a new context-aware semaphore type with Acquire and Release methods for use in upcoming changes that aim to simplify connection limiting by making use of semaphores for blocking until permits become available. --- internal/connmgr/semaphore.go | 36 +++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 internal/connmgr/semaphore.go diff --git a/internal/connmgr/semaphore.go b/internal/connmgr/semaphore.go new file mode 100644 index 000000000..fb7d7eed4 --- /dev/null +++ b/internal/connmgr/semaphore.go @@ -0,0 +1,36 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import "context" + +// semaphore is a simple context-aware channel based semaphore for bounding +// concurrent access. +type semaphore chan struct{} + +// makeSemaphore returns a new semaphore with the given capacity. +func makeSemaphore(n uint32) semaphore { + return make(chan struct{}, n) +} + +// Acquire acquires the semaphore. It blocks until resources are available or +// the provided context is done. It returns true on success and false when the +// context is done before semaphore can be acquired. +func (s semaphore) Acquire(ctx context.Context) bool { + select { + case s <- struct{}{}: + case <-ctx.Done(): + return false + } + return true +} + +// Release release the semaphore. +func (s semaphore) Release() { + select { + case <-s: + default: + } +} From 704eb8509483489f301ad2f8f1045353e3c1b182 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:08 -0500 Subject: [PATCH 02/51] connmgr: Add semaphore tests. This adds tests for the new context-aware semaphore to ensure the acquire, release, and context cancel semantics work as expected. --- internal/connmgr/semaphore_test.go | 109 +++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 internal/connmgr/semaphore_test.go diff --git a/internal/connmgr/semaphore_test.go b/internal/connmgr/semaphore_test.go new file mode 100644 index 000000000..9542176df --- /dev/null +++ b/internal/connmgr/semaphore_test.go @@ -0,0 +1,109 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "context" + "testing" + "time" +) + +// TestSemaphore ensures the semaphore acquire, release, and context cancel +// semantics are as expected. +func TestSemaphore(t *testing.T) { + // Create a closure that acquires a semaphore with a timeout. + ctx := context.Background() + timedAcquire := func(sem semaphore, timeout time.Duration) bool { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return sem.Acquire(ctx) + } + + // perSemTest describes a test to run against the same semaphore. + type perSemTest struct { + name string // test description + numAcquires uint32 // num to acquire + numReleases uint32 // num to release + } + + tests := []struct { + name string // test description + cap uint32 // capacity of the semaphore + perSemTests []perSemTest // tests to run against same semaphore + want []bool // expected results + }{{ + name: "normal block/release behavior", + cap: 2, + perSemTests: []perSemTest{{ + name: "cap 2 (0 acquired): acquire 3, release 1", + numAcquires: 3, + numReleases: 1, + }, { + name: "cap 2 (1 acquired): acquire 2, release 0", + numAcquires: 2, + numReleases: 0, + }, { + name: "cap 2 (2 acquired): acquire 1, release 2", + numAcquires: 1, + numReleases: 2, + }}, + want: []bool{true, true, false, true, false, false}, + }, { + // Releasing more than acquired ignores the extra release and does not + // influence future ops. + name: "relase more than acquired", + cap: 5, + perSemTests: []perSemTest{{ + name: "cap 5 (0 acquired): acquire 1, release 2", + numAcquires: 1, + numReleases: 2, + }, { + name: "cap 5 (0 acquired): acquire 5, release 1", + numAcquires: 5, + numReleases: 1, + }, { + name: "cap 5 (4 acquired): acquire 2, release 5", + numAcquires: 2, + numReleases: 5, + }}, + want: []bool{true, true, true, true, true, true, true, false}, + }} + + for _, test := range tests { + // Create semaphore with the capacity specified in the test and the + // a slice to hold the results. + sem := makeSemaphore(test.cap) + results := make([]bool, 0, len(test.want)) + + // Perform each sequence of acquires and releases as specified by the + // per semaphore tests. + for _, psTest := range test.perSemTests { + const timeout = 10 * time.Millisecond + for range psTest.numAcquires { + results = append(results, timedAcquire(sem, timeout)) + } + for range psTest.numReleases { + sem.Release() + } + } + + if len(results) != len(test.want) { + t.Errorf("%q: unexpected number of results: got %d, want %d", + test.name, len(results), len(test.want)) + } + for i := range results { + if results[i] != test.want[i] { + t.Errorf("%q: unexpected result for [%d]: got %v, want %v", + test.name, i, results[i], test.want[i]) + } + } + + // Ensure all acquires were released as expected. + if numAcquired := uint32(len(sem)); numAcquired != 0 { + t.Errorf("%q: unexpected final semaphore count: got %v, want %v", + test.name, numAcquired, 0) + } + } +} From 7db9f9e35ae9a93e124a281384d70410d912b3b0 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:09 -0500 Subject: [PATCH 03/51] connmgr: Overhaul to use wrapped conns plus ctx. The existing connection manager code was written well before contexts were introduced. Further, due to the old async model that has now been converted to a synchronous model, it is based around connection requests that have their state atomically updated asynchronously as various things happen. While it has undoubtedly worked well enough for over a decade, it has always been a challenge to add new functionality to it and requires the use of a lot of less than ideal and highly outdated techniques such as polling for state changes. It is also rather brittle in terms of requiring output connections to be manually disconnected in the connection manager after they've been closed to avoid things like leaking goroutines and failing to update target outbound counts. Moreover, it only tracks outgoing connections which ultimately forces a lot of connection-related tasks to be split across different layers instead of residing in the connection manager itself where they more naturally belong. Notably, that split, for all intents and purposes, prevents implementing some desirable more advanced features such as immediate connection shedding, different connection types, and listeners tied to specific network types. With the primary goal of addressing all of the aforementioned points and providing a solid base to work on for adding new features moving forward, this significantly reworks the connection manager to completely get rid of the notion of exposed connection requests in favor of a new custom connection type that wraps the underlying net.Conn. The new wrapped connections automatically handle cleanup when closed and have an associated connection type enum that allows easily distinguishing inbound, outbound, and manual connections as well as supporting new connection types in the future. Another nice feature of the new wrapped connections is they provide efficient access to concrete parsed address types which paves the way for avoiding a lot of constant reparsing, repeated host/port splitting and joining, and generally much more ergonomic immutable address types. Since changing to wrapped connections basically required a rather significant rewrite of large portions of the connection manager anyway, this also takes the opportunity to improve several other aspects of the connection manager in the process such as implementing full context support, full tracking of all connection types by the manager itself, much more robust semaphore-based automatic connection limiting, cleaner persistent connection handling with independent limits, prevention of multiple connections of any type to the same address:port, more useful debug logging, and cleanly closing all connections during shutdown. It is also important to note that the following overall semantics have intentionally been changed versus the existing connection manager: - A maximum of 8 persistent connections is now imposed and they no longer count toward the configured target number of automatic outbound peers to maintain - Duplicate addresses (host:port) are now rejected by the connection manager for all types (inbound, outbound, manual, persistent) - Note that inbound conns from the same IP will necessarily have different ports, so the same max IP limits apply in that case - RPC 'node connect' for all connection attempts now: - Supports the RPC connection and server contexts - Properly handles duplicate address rejection including pending attempts - RPC 'node connect' for non-persistent conn attempts now: - Waits for the connection attempt result before returning - Returns an error if the connection attempt fails - Cancels the connection attempt if the RPC connection is closed before it succeeds - RPC 'node remove' now supports removing a pending connection by its persistent connection ID (since no peer ID exists before a valid connection is established) - It is no longer possible for state transitions to allow things like duplicate addresses or failed cancellation --- internal/connmgr/connmanager.go | 1353 +++++++++++++----- internal/connmgr/connmanager_test.go | 701 ++++----- internal/connmgr/conntype_test.go | 35 + internal/connmgr/error.go | 26 +- internal/connmgr/error_test.go | 8 +- internal/rpcserver/interface.go | 2 +- internal/rpcserver/rpcserver.go | 94 +- internal/rpcserver/rpcserverhandlers_test.go | 8 +- rpcadaptors.go | 159 +- server.go | 133 +- 10 files changed, 1544 insertions(+), 975 deletions(-) create mode 100644 internal/connmgr/conntype_test.go diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 60e1afcc1..83ddf10c5 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -7,19 +7,29 @@ package connmgr import ( "context" + "errors" "fmt" "net" + "strconv" "sync" "sync/atomic" "time" + + "github.com/decred/dcrd/addrmgr/v4" ) -var ( +const ( + // MaxPersistent is the maximum number of persistent connections that can be + // added. Persistent connections do not count towards the automatic + // outbound connection limits. + MaxPersistent = 8 +) - // maxRetryDuration is the max duration of time retrying of a persistent - // connection is allowed to grow to. This is necessary since the retry - // logic uses a backoff mechanism which increases the interval base times - // the number of retries that have been done. +var ( + // maxRetryDuration is the maximum duration a persistent connection retry + // backoff is allowed to grow to. This is necessary since the retry logic + // uses a backoff mechanism which increases the interval base times the + // number of retries that have been done. maxRetryDuration = time.Minute * 5 ) @@ -35,75 +45,166 @@ const ( // defaultTargetOutbound is the default number of outbound connections to // maintain. - defaultTargetOutbound = uint32(8) + defaultTargetOutbound = 8 ) -// ConnState represents the state of the requested connection. -type ConnState uint32 +// ConnectionType specifies the different types of supported connections. +type ConnectionType uint8 -// ConnState can be either pending, established, disconnected or failed. When -// a new connection is requested, it is attempted and categorized as -// established or failed depending on the connection result. An established -// connection which was disconnected is categorized as disconnected. const ( - ConnPending ConnState = iota - ConnEstablished - ConnDisconnected - ConnFailed - ConnCanceled + // ConnTypeInbound indicates the connection was established by a remote + // peer. No further details are known about this connection until a + // handshake takes place. + ConnTypeInbound ConnectionType = iota + + // ConnTypeOutbound indicates a normal outbound connection that was + // established with no additional restrictions imposed on the type of + // information that the local peer/server is willing to relay. + // + // Note that this in no way implies further restrictions may not be + // negotiated depending on the protocol messages exchanged between the two + // peers. + ConnTypeOutbound + + // ConnTypeManual indicates an outbound connection that was manually + // requested via [ConnManager.Connect] or [ConnManager.AddPersistent]. In + // practice, this connection type is the result of requesting manual + // connections via an RPC method (e.g. "node connect") or via command line + // configuration options (e.g. --addpeer and --connect). + ConnTypeManual + + // numConnTypes is the number of connection types. This entry MUST be the + // last entry in the enum. + numConnTypes ) -// ConnReq is the connection request to a network address. If permanent, the -// connection will be retried on disconnection. -type ConnReq struct { - // id is the unique identifier for this connection request. - id atomic.Uint64 +// connTypeStrings is a map of connection types to human-readable names for +// pretty printing. +var connTypeStrings = map[ConnectionType]string{ + ConnTypeInbound: "inbound", + ConnTypeOutbound: "outbound", + ConnTypeManual: "manual", +} - // state is the current connection state for this connection request. - state atomic.Uint32 +// String returns the [ConnectionType] in human-readable form. +func (connType ConnectionType) String() string { + if s, ok := connTypeStrings[connType]; ok { + return s + } - // The following fields are owned by the connection manager and must not - // be accessed without its connection mutex held. + return fmt.Sprintf("Unknown ConnectionType (%d)", uint8(connType)) +} + +// Conn houses information about a managed connection. It is the callers +// responsibility to always ensure [Conn.Close] is called when the connection +// is no longer required. +type Conn struct { + // The following variables are set at the time the instance is created and + // are safe for concurrent access. + // + // net.Conn is the underlying connection. It is embedded which makes all of + // its methods immediately available. // - // retryCount is the number of times a permanent connection request that - // fails to connect has been retried since the last successful connection. + // id is the unique identifier for this connection. // - // conn is the underlying network connection. It will be nil before a - // connection has been established. - retryCount uint32 - conn net.Conn + // connType specifies the connection type. + // + // remoteAddr is the remote address associated with the connection. It is + // a concrete address manager address. + // + // onClose is a callback that will be invoked when the connection is closed. + net.Conn + id uint64 + connType ConnectionType + remoteAddr addrmgr.NetAddress + onClose func() + + // closed houses whether or not the connection has already been closed. + closed atomic.Bool +} - // Addr is the address to connect to. - Addr net.Addr +// newConn returns a new connection given an underlying [net.Conn], connection +// ID, and connection type. +// +// The returned connection is automatically configured to invoke the provided on +// close handler followed by the [Config.OnDisconnection] that was configured +// when initially creating the connection manager when the connection is closed. +// The on close handler is invoked in the same goroutine as the caller of +// [Conn.Close] and [Config.OnDisconnection] is invoked in a separate goroutine. +func newConn(cm *ConnManager, conn net.Conn, id uint64, connType ConnectionType, remoteAddr *addrmgr.NetAddress, onClose func()) *Conn { + c := &Conn{Conn: conn, id: id, connType: connType, remoteAddr: *remoteAddr} + c.onClose = func() { + onClose() + if cm.cfg.OnDisconnection != nil { + go cm.cfg.OnDisconnection(c) + } + } + return c +} - // Permanent specifies whether or not the connection request represents what - // should be treated as a permanent connection, meaning the connection - // manager will try to always maintain the connection including retries with - // increasing backoff timeouts. - Permanent bool +// ID returns a unique identifier for the connection. +// +// This function is safe for concurrent access. +func (c *Conn) ID() uint64 { + return c.id } -// updateState updates the state of the connection request. -func (c *ConnReq) updateState(state ConnState) { - c.state.Store(uint32(state)) +// Close closes the connection. The [Config.OnDisconnection] that was +// configured when initially creating the connection manager will be invoked in +// a separate goroutine prior to closing the underlying connection. +// +// Repeated close attempts are ignored. Closing a connection that has already +// been closed will not return an error. +// +// This function is safe for concurrent access. +func (c *Conn) Close() error { + // Already closed. + if !c.closed.CompareAndSwap(false, true) { + return nil + } + + // Invoke close callback associated with the connection when it's closed. + if c.onClose != nil { + c.onClose() + } + + // Close the underlying connection. + return c.Conn.Close() } -// ID returns a unique identifier for the connection request. -func (c *ConnReq) ID() uint64 { - return c.id.Load() +// RemoteAddr returns the remote address manager network address associated with +// the connection. It returns a [net.Addr] to implement the [net.Conn] +// interface, but the underlying type will be a [*addrmgr.NetAddress]. +func (c *Conn) RemoteAddr() net.Addr { + return &c.remoteAddr } -// State is the connection state of the requested connection. -func (c *ConnReq) State() ConnState { - return ConnState(c.state.Load()) +// Type returns the [ConnectionType] of the connection. +// +// This function is safe for concurrent access. +func (c *Conn) Type() ConnectionType { + return c.connType } -// String returns a human-readable string for the connection request. -func (c *ConnReq) String() string { - if c.Addr == nil || c.Addr.String() == "" { - return fmt.Sprintf("reqid %d", c.ID()) - } - return fmt.Sprintf("%s (reqid %d)", c.Addr, c.ID()) +// pendingConnInfo houses information about a pending connection attempt. +type pendingConnInfo struct { + id uint64 + addr *addrmgr.NetAddress + cancel context.CancelFunc +} + +// persistentEntry houses information about a persistent connection that has +// been added to the connection manager. Once an ID has been assigned, all +// future connections established for the persistent connection will have the +// same ID. This allows it to be uniquely identified and removed later. +type persistentEntry struct { + id uint64 + addr *addrmgr.NetAddress + + // cancel shuts down the goroutine that maintains the persistent connection. + // It is owned by the connection manager and must not be accessed without + // its connection mutex held. + cancel context.CancelFunc } // Config holds the configuration options related to the connection manager. @@ -129,10 +230,13 @@ type Config struct { // This field will not have any effect if the Listeners field is not // also specified since there couldn't possibly be any accepted // connections in that case. - OnAccept func(net.Conn) + OnAccept func(*Conn) - // TargetOutbound is the number of outbound network connections to - // maintain. Defaults to 8. + // TargetOutbound is the number of outbound network connections to maintain + // automatically. Defaults to 8. + // + // Persistent connections do not count against this value. They have their + // own maximum limit defined by [MaxPersistent]. TargetOutbound uint32 // RetryDuration is the duration to wait before retrying connection @@ -141,11 +245,10 @@ type Config struct { // OnConnection is a callback that is fired when a new outbound // connection is established. - OnConnection func(*ConnReq, net.Conn) + OnConnection func(*Conn) - // OnDisconnection is a callback that is fired when an outbound - // connection is disconnected. - OnDisconnection func(*ConnReq) + // OnDisconnection is a callback that is fired when a connection is closed. + OnDisconnection func(*Conn) // GetNewAddress is a way to get an address to make a network connection // to. If nil, no new connections will be made automatically. @@ -161,13 +264,8 @@ type Config struct { // ConnManager provides a manager to handle network connections. type ConnManager struct { - // connReqCount is the number of connection requests that have been made and - // is primarily used to assign unique connection request IDs. - connReqCount atomic.Uint64 - - // assignIDMtx synchronizes the assignment of an ID to a connection request - // with overall connection request count above. - assignIDMtx sync.Mutex + // nextConnID is used to assign unique connection request IDs. + nextConnID atomic.Uint64 // quit is used for lifecycle management of the connection manager. quit chan struct{} @@ -176,424 +274,884 @@ type ConnManager struct { // creating time and treated as immutable after that. cfg Config - // failedAttempts tracks the total number of failed oubound connection - // attempts since the last successful connection made by the connection - // manager. It is primarily used to detect network outages in order to - // impose a retry timeout on achieving the target number of outbound - // connections which prevents runaway failed connection attempt churn. + // runPersistentChan is used to signal the persistent connections handler to + // launch a goroutine that attempts to always maintain an established + // connection with a given address. // - // This field is owned by the connection handler and must not be accessed - // outside of it. - failedAttempts uint64 + // It is a buffered channel with size [MaxPersistent]. + runPersistentChan chan *persistentEntry + + // outboundSem limits the number of active outbound connections. It does + // not apply to persistent connections which are separately limited to + // [MaxPersistent]. + activeOutboundsSem semaphore + + // The fields below this point are all protected by the connection mutex. + connMtx sync.Mutex - // The following fields are used to track the various connections managed - // by the connection manager. They are protected by the associated - // connection mutex. + // persistent tracks all registered persistent connection entries. // - // pending holds all registered connection requests that have yet to - // succeed. + // A persistent connection can be in one of three states: // - // conns represents the set of all active connections. - connMtx sync.Mutex - pending map[uint64]*ConnReq - conns map[uint64]*ConnReq -} + // - Established with the connection instance in the active map + // - Pending with an entry in the pending map + // - Awaiting a retry + // + // Regardless of the state, there will always be an entry in this map. + persistent map[uint64]*persistentEntry -// registerPending registers the provided connection request as a pending -// connection attempt. -// -// This function MUST be called with the connection mutex lock held (writes). -func (cm *ConnManager) registerPending(connReq *ConnReq) { - connReq.updateState(ConnPending) - cm.pending[connReq.ID()] = connReq + // pending holds all registered connection attempts that have yet to + // succeed. + pending map[uint64]*pendingConnInfo + + // active represents the set of all active connections. + active map[uint64]*Conn + + // connIDByAddr provides fast O(1) lookup of connection IDs by address + // (host:port). It is kept in sync with the persistent, pending, and active + // maps and is primarily used to efficiently reject duplicate connections. + connIDByAddr map[string]uint64 } -// newConnReq creates a new connection request and connects to the corresponding -// address. -func (cm *ConnManager) newConnReq(ctx context.Context) { - // Ignore during shutdown. - if ctx.Err() != nil { - return +// checkShutdown returns [ErrShutdown] when the connection manager quit channel +// has been closed. +func (cm *ConnManager) checkShutdown() error { + select { + case <-cm.quit: + const str = "connection manager shutdown" + return MakeError(ErrShutdown, str) + default: } + return nil +} - c := &ConnReq{} - c.id.Store(cm.connReqCount.Add(1)) +// stdlibNetAddrToAddrMgrNetAddr converts the provided standard lib [net.Addr] +// to a concrete address manager address. +func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { + host, portStr, err := net.SplitHostPort(addr.String()) + if err != nil { + str := fmt.Sprintf("unable to split address %q", addr) + return nil, MakeError(ErrUnsupportedAddr, str) + } + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + str := fmt.Sprintf("invalid port for address %q", addr) + return nil, MakeError(ErrUnsupportedAddr, str) + } - // Register the pending connection attempt so it can be canceled via the - // [ConnManager.Remove] method. - cm.connMtx.Lock() - cm.registerPending(c) - cm.connMtx.Unlock() + addrType, addrBytes := addrmgr.EncodeHost(host) + if addrType == addrmgr.UnknownAddressType { + str := fmt.Sprintf("unable to determine address type for %q", addr) + return nil, MakeError(ErrUnsupportedAddr, str) + } - addr, err := cm.cfg.GetNewAddress() + now := time.Unix(time.Now().Unix(), 0) + netAddr, err := addrmgr.NewNetAddressFromParams(addrType, addrBytes, + uint16(port), now, 0) if err != nil { - cm.connMtx.Lock() - cm.handleFailedPending(ctx, c, err) - cm.connMtx.Unlock() - return + return nil, MakeError(ErrUnsupportedAddr, err.Error()) } + return netAddr, nil +} - c.Addr = addr +// addPendingInfo adds information about a pending connection attempt to the +// local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addPendingInfo(info *pendingConnInfo) { + cm.pending[info.id] = info + if _, ok := cm.persistent[info.id]; !ok { + cm.connIDByAddr[info.addr.String()] = info.id + } +} - cm.Connect(ctx, c) +// removePendingInfo removes a pending connection attempt from the local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removePendingInfo(info *pendingConnInfo) { + delete(cm.pending, info.id) + if _, ok := cm.persistent[info.id]; !ok { + delete(cm.connIDByAddr, info.addr.String()) + } } -// handleFailedConn handles a connection failed due to a disconnect or any other -// failure. Permanent connection requests are retried after the configured -// retry duration. A new connection request is created if required. +// addActiveConn adds an established connection to the local state. // -// In the event there have been [maxFailedAttempts] failed successive attempts, -// new connections will be retried after the configured retry duration. +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addActiveConn(conn *Conn) { + cm.active[conn.id] = conn + if _, ok := cm.persistent[conn.id]; !ok { + cm.connIDByAddr[conn.remoteAddr.String()] = conn.id + } +} + +// removeActiveConn removes an established connection from the local state. It +// has no effect if the connection has already been removed from the active map. // -// This function MUST be called with the connection lock held (writes). -func (cm *ConnManager) handleFailedConn(ctx context.Context, c *ConnReq) { - // Ignore during shutdown. - select { - case <-cm.quit: - return - case <-ctx.Done(): +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removeActiveConn(conn *Conn) { + // The active connection might have already been removed before releasing + // the mutex to call [Conn.Close]. + if _, ok := cm.active[conn.id]; !ok { return - default: } - // Reconnect to permanent connection requests after a retry timeout with - // an increasing backoff up to a max for repeated failed attempts. - if c.Permanent { - c.retryCount++ - retryWait := time.Duration(c.retryCount) * cm.cfg.RetryDuration - retryWait = min(retryWait, maxRetryDuration) - log.Debugf("Retrying connection to %v in %v", c, retryWait) - go func() { - select { - case <-time.After(retryWait): - cm.Connect(ctx, c) - case <-cm.quit: - case <-ctx.Done(): - } - }() - return + delete(cm.active, conn.id) + if _, ok := cm.persistent[conn.id]; !ok { + delete(cm.connIDByAddr, conn.remoteAddr.String()) } +} - // Nothing more to do when the method to automatically get new addresses - // to connect to isn't configured. - if cm.cfg.GetNewAddress == nil { - return +// addPersistentEntry adds a persistent connection entry to the local state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) addPersistentEntry(entry *persistentEntry) { + cm.persistent[entry.id] = entry + cm.connIDByAddr[entry.addr.String()] = entry.id +} + +// removePersistentEntry removes a persistent connection entry from the local +// state. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) removePersistentEntry(entry *persistentEntry) { + delete(cm.persistent, entry.id) + _, pending := cm.pending[entry.id] + _, active := cm.active[entry.id] + if !pending && !active { + delete(cm.connIDByAddr, entry.addr.String()) } +} - // Wait to attempt new connections when there are too many successive - // failures. This prevents massive connection spam when no connections can - // be made, such as a network outtage. - cm.failedAttempts++ - if cm.failedAttempts >= maxFailedAttempts { - log.Debugf("Max failed connection attempts reached: [%d] -- retrying "+ - "connection in: %v", maxFailedAttempts, cm.cfg.RetryDuration) - go func() { - select { - case <-time.After(cm.cfg.RetryDuration): - cm.newConnReq(ctx) - case <-cm.quit: - case <-ctx.Done(): - } - }() - return +// rejectConnectedAddr returns an error if there is already either an +// established connection to the provided address or a pending attempt to +// connect to it. Persistent connections in the retry state are intentionally +// not detected. +// +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectConnectedAddr(addr *addrmgr.NetAddress) error { + connID, ok := cm.connIDByAddr[addr.String()] + if !ok { + return nil } - // Otherwise, attempt a new connection with a new address now. - go cm.newConnReq(ctx) + if _, ok := cm.pending[connID]; ok { + str := fmt.Sprintf("a pending connection to %s already exists", addr) + return MakeError(ErrAlreadyPending, str) + } + if _, ok := cm.active[connID]; ok { + str := fmt.Sprintf("a connection to %s is already established", addr) + return MakeError(ErrAlreadyConnected, str) + } + return nil } -// handleFailedPending handles failed pending connection requests. Connection -// requests that were canceled are ignored. Otherwise, their state is updated -// to mark it failed and it is passed along to [ConnManager.handleFailedConn] to -// possibly retry or be reused for a new connection depending on settings. +// findPersistentAddrID attempts to find and return the persistent connection ID +// associated with the passed address. The bool return indicates whether or not +// it was found. // -// This function MUST be called with the connection lock held (writes). -func (cm *ConnManager) handleFailedPending(ctx context.Context, c *ConnReq, failedErr error) { - if _, ok := cm.pending[c.ID()]; !ok { - log.Debugf("Ignoring connection for canceled conn req: %v", c) - return +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) findPersistentAddrID(addr net.Addr) (uint64, bool) { + connID, ok := cm.connIDByAddr[addr.String()] + if !ok { + return 0, false + } + + entry, ok := cm.persistent[connID] + if !ok { + return 0, false } - c.updateState(ConnFailed) - log.Debugf("Failed to connect to %v: %v", c, failedErr) - cm.handleFailedConn(ctx, c) + return entry.id, true } -// Connect assigns an id and dials a connection to the address of the connection -// request using the provided context and the dial function configured when -// initially creating the connection manager. +// rejectPersistentAddr returns an error if there is already a persistent +// connection entry for the given address. // -// The connection attempt will be ignored if the connection manager has been -// shutdown by canceling the lifecycle context the Run method was invoked with -// or the provided connection request is already in the failed state. +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectPersistentAddr(addr *addrmgr.NetAddress) error { + if _, ok := cm.findPersistentAddrID(addr); ok { + str := fmt.Sprintf("a persistent connection for %s already exists", addr) + return MakeError(ErrDuplicatePersistent, str) + } + return nil +} + +// rejectDuplicateAddr returns an error if there is already a persistent +// connection entry, a pending connection attempt, or an established connection +// for the given address. // -// Note that the context parameter to this function and the lifecycle context -// may be independent. -func (cm *ConnManager) Connect(ctx context.Context, c *ConnReq) { +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { + if err := cm.rejectPersistentAddr(addr); err != nil { + return err + } + if err := cm.rejectConnectedAddr(addr); err != nil { + return err + } + return nil +} + +// dial attempts to connect to the provided address and returns a connection +// configured with the provided params on success. +// +// A new globally unique connection ID is assigned unless one is provided by +// passing a non-nil value in the persistent connection ID parameter. This +// allows persistent connections to retain the same ID across reconnects. +// +// Attempts to dial addresses that are already connected, pending, or (in most +// cases) persistent will return an error as described below. Only established +// and pending connections are rejected when a non-nil persistent connection ID +// is passed. +// +// On success, the returned connection is configured to remove itself from the +// set of all active connections and invoke the provided on close callback (if +// set) when it is closed. +// +// On failure, the provided on close callback (when non-nil) will be invoked +// prior to returning. +// +// In addition to errors returned by [Config.Dial], the following errors are +// possible: +// +// - [ErrDuplicatePersistent] when a persistent connection already exists for +// the address and no persistent connection ID is provided +// - [ErrAlreadyPending] when there is already a pending connection attempt +// to the address +// - [ErrAlreadyConnected] when there is already an established connection to +// the address +// - [ErrShutdown] when the connection manager is shutting down +// - [context.Canceled] or [context.DeadlineExceeded] depending on the +// provided context or when the dialer fails to establish a connection +// before the timeout configured for the connection manager +// +// This function is safe for concurrent access. +func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType ConnectionType, onClose func(), persistentConnID *uint64) (*Conn, error) { + var skipOnClose bool + defer func() { + if !skipOnClose && onClose != nil { + onClose() + } + }() + // Ignore during shutdown and when caller provided context is already // canceled. - select { - case <-cm.quit: - return - default: + if err := cm.checkShutdown(); err != nil { + return nil, err } if ctx.Err() != nil { - return + return nil, ctx.Err() } - // During the time we wait for retry there is a chance that this - // connection was already cancelled. - if c.State() == ConnCanceled { - log.Debugf("Ignoring connect for canceled connreq=%v", c) - return + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return nil, err } - // Assign an ID and register the pending connection attempt when an ID has - // not already been assigned so it can be canceled via the - // [ConnManager.Remove] method. + // Reject attempts to dial addresses that are already connected (or in the + // process of it). Additionally, reject attempts to dial existing + // persistent addresses unless a persistent connection ID was provided + // indicating the dial is specifically for a persistent connection. // - // Note that the assignment of the ID and the overall request count need to - // be synchronized. So long as this is the only place an existing conn - // request ID is updated and this method is not called concurrently on the - // same conn request, no race could occur. However, those preconditions - // would be easy to inadvertently violate via updates to the code, so the - // mutex is added here for additional safety. - var doRegisterPending bool - cm.assignIDMtx.Lock() - if c.ID() == 0 { - c.id.Store(cm.connReqCount.Add(1)) - doRegisterPending = true - } - connReqID := c.ID() - cm.assignIDMtx.Unlock() - if doRegisterPending { - cm.connMtx.Lock() - cm.registerPending(c) + // This needs to be done under the same lock as adding a pending entry to + // prevent the possibility of two simultaneous attempts logic racing. + rejectFn := cm.rejectDuplicateAddr + if persistentConnID != nil { + rejectFn = cm.rejectConnectedAddr + } + cm.connMtx.Lock() + if err := rejectFn(rAddr); err != nil { cm.connMtx.Unlock() + log.Debugf("Rejected connection: %v", err) + return nil, err } - log.Debugf("Attempting to connect to %v", c) - - // Attempt to establish the connection to the address associated with the - // connection request. Apply a timeout if requested. + // Apply a dial timeout if requested. Otherwise, use a regular cancel + // context to support canceling the pending connection later. + var cancel context.CancelFunc if cm.cfg.DialTimeout != 0 { - var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, cm.cfg.DialTimeout) - defer cancel() + } else { + ctx, cancel = context.WithCancel(ctx) } - var conn net.Conn - conn, err := cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String()) - if err != nil { + defer cancel() + + // Register the pending connection attempt and defer its removal to ensure + // it is always removed on failure. + var connID uint64 + if persistentConnID != nil { + connID = *persistentConnID + } else { + connID = cm.nextConnID.Add(1) + } + info := &pendingConnInfo{connID, rAddr, cancel} + cm.addPendingInfo(info) + cm.connMtx.Unlock() + defer func() { cm.connMtx.Lock() - cm.handleFailedPending(ctx, c, err) + if _, ok := cm.pending[connID]; ok { + cm.removePendingInfo(info) + } cm.connMtx.Unlock() - return + }() + + log.Debugf("Attempting to connect to %v (id: %d, type: %v)", addr, connID, + connType) + + // Attempt to establish the connection to the address. + netConn, err := cm.cfg.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + var logErrStr string + switch { + case errors.Is(err, context.DeadlineExceeded): + logErrStr = fmt.Sprintf("no response for %v", cm.cfg.DialTimeout) + case errors.Is(err, context.Canceled): + // Override the error with the shutdown error instead when that is + // the upstream cause of the context cancel. + if sErr := cm.checkShutdown(); sErr != nil { + err = sErr + break + } + logErrStr = "attempt manually canceled" + } + if logErrStr == "" { + logErrStr = err.Error() + } + log.Debugf("Failed to connect to %v: %v", addr, logErrStr) + return nil, err } + // Ignore any connections that succeed after they were manually canceled. cm.connMtx.Lock() - defer cm.connMtx.Unlock() + if _, ok := cm.pending[connID]; !ok { + cm.connMtx.Unlock() + netConn.Close() + log.Debugf("Ignoring canceled connection %v (id: %d, type: %v)", addr, + connID, connType) + return nil, context.Canceled + } - if _, ok := cm.pending[connReqID]; !ok { - conn.Close() - log.Debugf("Ignoring connection for canceled connreq=%v", c) - return + // Remove the pending entry under the lock. This ensures the maps are + // mutually exclusive for a given id. + cm.removePendingInfo(info) + + // Successful return means the on close callback is not invoked until the + // connection is closed. + skipOnClose = true + + // Setup a close callback to remove the connection from the map that tracks + // all active connections when the connection is closed and also to invoke + // the close callback provided by the caller when specified. + var conn *Conn + dialOnClose := func() { + cm.connMtx.Lock() + cm.removeActiveConn(conn) + cm.connMtx.Unlock() + if onClose != nil { + onClose() + } + log.Debugf("Disconnected from %v (id: %d, type: %v)", addr, connID, + connType) } - c.updateState(ConnEstablished) - c.conn = conn - cm.conns[connReqID] = c - log.Debugf("Connected to %v", c) - c.retryCount = 0 - cm.failedAttempts = 0 - delete(cm.pending, connReqID) + // Create a new connection instance with the connection ID and type and add + // an entry to the map that tracks all active connections. + conn = newConn(cm, netConn, connID, connType, rAddr, dialOnClose) + cm.addActiveConn(conn) + cm.connMtx.Unlock() + + log.Debugf("Connected to %v (id: %d, type: %v)", addr, connID, connType) + return conn, nil +} +// Connect assigns an ID and dials a connection to the provided address using +// the provided context and the dial function configured when initially creating +// the connection manager. +// +// Attempts to dial addresses that already have an established, pending, or +// persistent connection will return an error as described below. +// +// The connection will have type [ConnTypeManual]. +// +// Note that the context parameter to this function and the lifecycle context +// may be independent. +// +// In addition to errors returned by the underlying dialer, the following errors +// are possible: +// +// - [ErrDuplicatePersistent] when a persistent connection already exists for +// the address (regardless of its current state) +// - [ErrAlreadyPending] when there is already a pending connection attempt +// to the address +// - [ErrAlreadyConnected] when there is already an established connection to +// the address +// - [ErrShutdown] when the connection manager is shutting down +// - [context.Canceled] or [context.DeadlineExceeded] depending on the +// provided context or when the dialer fails to establish a connection +// before the timeout configured for the connection manager +// +// This function is safe for concurrent access. +func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error) { + conn, err := cm.dial(ctx, addr, ConnTypeManual, nil, nil) + if err != nil { + return nil, err + } if cm.cfg.OnConnection != nil { - go cm.cfg.OnConnection(c, conn) + go cm.cfg.OnConnection(conn) } + return conn, nil } -// handleDisconnected handles a connection that has been disconnected. +// Disconnect either disconnects the connection corresponding to the given +// connection id or cancels any pending attempts associated with it. Persistent +// connections will be retried with an increasing backoff duration. // -// This function MUST be called with the connection mutex held (writes). -func (cm *ConnManager) handleDisconnected(id uint64, retry bool) { - // Mark the connection request as canceled and remove it from the pending - // connections when it is still pending. Since the connection attempt is - // taking place asynchronously, this ensures any later successful connection - // is ignored. - connReq, ok := cm.pending[id] - if ok { - connReq.updateState(ConnCanceled) - log.Debugf("Canceling: %v", connReq) - delete(cm.pending, id) - } - - connReq, ok = cm.conns[id] - if !ok { - log.Errorf("Unknown connid=%d", id) - return +// This function is safe for concurrent access. +func (cm *ConnManager) Disconnect(id uint64) error { + // Cancel and remove pending entries. Even though the pending entry will be + // removed from the map regardless by the dialer, doing it now ensures that + // any connections that are already in progress and later succeed are + // ignored. + cm.connMtx.Lock() + if info, ok := cm.pending[id]; ok { + info.cancel() + cm.removePendingInfo(info) + cm.connMtx.Unlock() + return nil } - // Close the underlying connection and invoke the associated callback (if - // assigned). - log.Debugf("Disconnected from %v", connReq) - delete(cm.conns, id) - if connReq.conn != nil { - connReq.conn.Close() - } - if cm.cfg.OnDisconnection != nil { - go cm.cfg.OnDisconnection(connReq) + conn := cm.active[id] + if conn != nil { + cm.connMtx.Unlock() + conn.Close() // Close requires the conn mutex. + return nil } + _, isPersistent := cm.persistent[id] + cm.connMtx.Unlock() - // Mark the associated connection request as disconnected and return when no - // further attempts will be made now that all internal state has been - // cleaned up. - if !retry { - connReq.updateState(ConnDisconnected) - return + // Not found in active or pending, but it might still be a persistent conn + // waiting to retry. No error in that case. + if isPersistent { + return nil } - // Otherwise, attempt a reconnection when the associated connection request - // is marked as permanent or there are not already enough outbound peers to - // satisfy the target number of outbound peers. - numConns := uint32(len(cm.conns)) - if connReq.Permanent || numConns < cm.cfg.TargetOutbound { - // The connection request is reused for permanent ones, so add it back - // to the pending map in that case so that subsequent processing of - // connections and failures do not ignore the request. - if connReq.Permanent { - cm.registerPending(connReq) - log.Debugf("Reconnecting to %v", connReq) - } - - // A background context is the only viable choice here. It is not - // ideal, but it is acceptable, because, ultimately, this context is - // really only used for persistent peers when they retry and persistent - // peers are not tied to a specific context anyway. They are instead - // removed by other means. Due to that, there also is no machinery to - // cancel a given persistent peer from a given context anyway. - // - // Future work ideally should refactor the persistent peer handling to - // have proper full context support. - cm.handleFailedConn(context.Background(), connReq) - } + str := fmt.Sprintf("no entries with id %d exist", id) + return MakeError(ErrNotFound, str) } -// Disconnect disconnects the connection corresponding to the given connection -// id. Permanent connections will be retried with an increasing backoff -// duration. +// Remove closes, cancels, or removes the connection corresponding to the given +// connection id. // -// This function is safe for concurrent access. -func (cm *ConnManager) Disconnect(id uint64) { - cm.connMtx.Lock() - cm.handleDisconnected(id, true) - cm.connMtx.Unlock() -} - -// Remove removes the connection corresponding to the given connection id from -// known connections. +// This function may be used for all connections states and types, including +// established, pending, and persistent connections. // -// NOTE: This method can also be used to cancel a pending connection attempt -// that hasn't yet succeeded. +// Connections that are already established are closed and connection attempts +// that are still pending are canceled. Persistent connections are additionally +// removed so that no future retries will occur. // // This function is safe for concurrent access. -func (cm *ConnManager) Remove(id uint64) { +func (cm *ConnManager) Remove(id uint64) error { + // When the ID is for a persistent connection, cancel the associated context + // and remove it from the persistent map to prevent future retries. cm.connMtx.Lock() - cm.handleDisconnected(id, false) + entry, isPersistent := cm.persistent[id] + if isPersistent { + cm.removePersistentEntry(entry) + if entry.cancel != nil { + entry.cancel() + } + log.Debugf("Removed persistent connection to %v (id %d)", entry.addr, + entry.id) + } + + // Cancel and remove pending entries. Even though the pending entry will be + // removed from the map regardless by the dialer, doing it now ensures that + // any connections that are already in progress and later succeed are + // ignored. + if info, ok := cm.pending[id]; ok { + info.cancel() + cm.removePendingInfo(info) + cm.connMtx.Unlock() + return nil + } + + // Close active connections and remove the entry from the active map. + // + // Even though the connection close handler would remove it from the map, it + // needs to be removed under same lock as removals from the persistent map + // to prevent the possibility of two simultaneous attempts logic racing. + if conn, ok := cm.active[id]; ok { + cm.removeActiveConn(conn) + cm.connMtx.Unlock() + conn.Close() // Close requires the conn mutex. + return nil + } cm.connMtx.Unlock() + + // Not found in active or pending, but no error if it was a removed + // persistent conn. + if isPersistent { + return nil + } + + str := fmt.Sprintf("no entries with id %d exist", id) + return MakeError(ErrNotFound, str) } -// findPendingByAddr attempts to find and return the pending connection request -// associated with the provided address. It returns nil if no matching request -// is found. -// -// This function MUST be called with the connection mutex held (writes). -func (cm *ConnManager) findPendingByAddr(addr net.Addr) *ConnReq { - pendingAddr := addr.String() - for _, req := range cm.pending { - if req == nil || req.Addr == nil { +// inboundStdlibNetAddrToAddrMgrAddr converts the provided standard library +// [net.Addr] that is expected to be from an inbound connection to a concrete +// address manager address. +func inboundStdlibNetAddrToAddrMgrAddr(addr net.Addr) (*addrmgr.NetAddress, error) { + // Fast path for inbounds since they will almost always be one of these + // given they are created by [net.Listener.Accept]. + switch a := addr.(type) { + case *net.TCPAddr: + return addrmgr.NewNetAddressFromIPPort(a.IP, uint16(a.Port), 0), nil + case *net.UDPAddr: + return addrmgr.NewNetAddressFromIPPort(a.IP, uint16(a.Port), 0), nil + } + + // Fall back to slower string parsing. + return stdlibNetAddrToAddrMgrNetAddr(addr) +} + +// listenHandler accepts incoming connections on a given listener. It must be +// run as a goroutine. +func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { + log.Infof("Server listening on %s", listener.Addr()) + defer log.Tracef("Listener handler done for %s", listener.Addr()) + + for ctx.Err() == nil { + netConn, err := listener.Accept() + if err != nil { + // Only log the error if not forcibly shutting down. + if ctx.Err() == nil { + log.Errorf("Can't accept connection: %v", err) + } + continue + } + + rAddr, err := inboundStdlibNetAddrToAddrMgrAddr(netConn.RemoteAddr()) + if err != nil { + log.Warnf("Dropped connection from %v: failed to parse address", + netConn.RemoteAddr()) + netConn.Close() continue } - if pendingAddr == req.Addr.String() { - return req + + // Reject connections with the same host:port as any existing pending, + // established, or persistent connections. Note that this does NOT + // prevent multiple connections from the same host given they typically + // will be coming from different ports. + // + // The aforementioned behavior is intentional as it allows connections + // from the same host to be independently limited to more than one + // elsewhere. + cm.connMtx.Lock() + if err := cm.rejectDuplicateAddr(rAddr); err != nil { + cm.connMtx.Unlock() + log.Debugf("Dropped connection from %v: %v", rAddr, err) + netConn.Close() + continue } + cm.connMtx.Unlock() + + go func(netConn net.Conn) { + // Create a new connection instance with the next globally unique + // connection ID, add an entry to the map that tracks all active + // connections, and invoke the configured accept callback with it. + // + // Also set a close callback to remove the connection from the map + // when it is closed. + id := cm.nextConnID.Add(1) + const connType = ConnTypeInbound + var conn *Conn + onClose := func() { + cm.connMtx.Lock() + cm.removeActiveConn(conn) + cm.connMtx.Unlock() + log.Debugf("Disconnected from %v (id: %d, type: %v)", rAddr, id, + connType) + } + conn = newConn(cm, netConn, id, connType, rAddr, onClose) + cm.connMtx.Lock() + cm.addActiveConn(conn) + cm.connMtx.Unlock() + log.Debugf("Accepted connection from %v (id: %d, type: %v)", rAddr, + id, connType) + cm.cfg.OnAccept(conn) + }(netConn) } - return nil } -// CancelPending removes the connection corresponding to the given address -// from the list of pending failed connections. +// AddPersistent adds an address the connection manager will attempt to always +// maintain an established connection with until the persistent connection entry +// is removed via [ConnManager.Remove] or the context associated with +// [ConnManager.Run] is canceled. +// +// When the associated connection is dropped, it will be retried with an +// increasing backoff, up to a maximum for repeated failed attempts. +// +// A maximum of [MaxPersistent] connections may be added. Attempting to add any +// more will return [ErrMaxPersistent]. +// +// Adding a duplicate persistent address will return [ErrDuplicatePersistent] +// and adding addresses that already have an established or pending connection +// will return [ErrAlreadyConnected] or [ErrAlreadyPending], respectively. +// +// An ID is returned that uniquely identifies the persistent connection. All +// future connections established will have the same ID. // -// Returns an error if the connection manager is stopped or there is no pending -// connection for the given address. -func (cm *ConnManager) CancelPending(addr net.Addr) error { +// Persistent connections do not count against [Config.TargetOutbound]. +// +// Note that the actual connections to the address happen asynchronously and +// will have type [ConnTypeManual]. Established connections will invoke the +// [Config.OnConnection] callback that was configured when initially creating +// the connection manager. +// +// Since connections happen asynchronously, the error only indicates issues with +// adding the persistent connection entry. +// +// The persistent connection may be removed by passing the returned connection +// ID to [ConnManager.Remove]. +// +// This function is safe for concurrent access. +func (cm *ConnManager) AddPersistent(addr net.Addr) (uint64, error) { cm.connMtx.Lock() defer cm.connMtx.Unlock() - connReq := cm.findPendingByAddr(addr) - if connReq == nil { - str := fmt.Sprintf("no pending connection to %v", addr) - return MakeError(ErrNotFound, str) + if len(cm.persistent)+1 > MaxPersistent { + str := fmt.Sprintf("a maximum of %d persistent connections is allowed", + MaxPersistent) + return 0, MakeError(ErrMaxPersistent, str) } - delete(cm.pending, connReq.ID()) - connReq.updateState(ConnCanceled) - log.Debugf("Canceled pending connection to %v", addr) - return nil + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return 0, err + } + + if err := cm.rejectDuplicateAddr(rAddr); err != nil { + return 0, err + } + + entry := &persistentEntry{id: cm.nextConnID.Add(1), addr: rAddr} + cm.addPersistentEntry(entry) + log.Debugf("Added persistent connection to %v (id: %d)", addr, entry.id) + + // The channel is buffered with the max allowed persistent conns, so there + // is no possibility of blocking here. This approach allows persistent + // peers to be added both before and after the connection manager is running + // without starting the goroutines before it is running. + cm.runPersistentChan <- entry + return entry.id, nil } -// ForEachConnReq calls the provided function with each connection request known -// to the connection manager, including pending requests. Returning an error -// from the provided function will stop the iteration early and return said -// error from this function. +// IsPersistent returns whether or not the provided connection id belongs to a +// persistent connection. // // This function is safe for concurrent access. +func (cm *ConnManager) IsPersistent(id uint64) bool { + cm.connMtx.Lock() + _, ok := cm.persistent[id] + cm.connMtx.Unlock() + return ok +} + +// FindPersistentAddrID attempts to find and return the persistent connection ID +// associated with the passed address. The bool return indicates whether or not +// it was found. // -// NOTE: This must not call any other connection manager methods during -// iteration or it will result in a deadlock. -func (cm *ConnManager) ForEachConnReq(f func(c *ConnReq) error) error { +// This function is safe for concurrent access. +func (cm *ConnManager) FindPersistentAddrID(addr net.Addr) (uint64, bool) { cm.connMtx.Lock() - defer cm.connMtx.Unlock() + id, ok := cm.findPersistentAddrID(addr) + cm.connMtx.Unlock() + return id, ok +} - var err error - for _, connReq := range cm.pending { - err = f(connReq) - if err != nil { - return err +// runPersistent attempts to maintain a persistent connection to the provided +// address until the passed context is canceled. +// +// When the associated connection is dropped, it will be retried with an +// increasing backoff, up to a maximum for repeated failed attempts. +// +// This MUST be run as a goroutine. +func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr net.Addr) { + // Ensure the connection is closed when the goroutine exits. + var conn *Conn + defer func() { + if conn != nil { + conn.Close() } + }() + + // Setup a callback that notifies a disconnect channel for use below and + // start with the channel signaled. + disconnected := make(chan struct{}, 1) + disconnected <- struct{}{} + onClose := func() { + disconnected <- struct{}{} } - for _, connReq := range cm.conns { - err = f(connReq) + + var retryCount uint32 + var retryAfter <-chan time.Time + var lastAttempt time.Time + for { + // Wait for disconnect or retry timer when it's set. + select { + case <-ctx.Done(): + return + case <-cm.quit: + return + case <-retryAfter: + retryAfter = nil + case <-disconnected: + // Wait to retry any time the connection was not maintained for at + // least a single retry interval. + // + // This approach is used over only incrementing the retry count when + // the dial fails to effectively rate limit the attempts with an + // increasing backoff regardless of the reason a stable connection + // was not maintained. + // + // For example, the remote might repeatedly reject the peer for a + // variety of reasons (max limits, not enough peers of a desired + // type, etc) after a successful connection is made. + if !lastAttempt.IsZero() && time.Since(lastAttempt) < cm.cfg.RetryDuration { + // Reconnect after a retry timeout with an increasing backoff up + // to a max for repeated failed attempts. + const maxUint32 = 1<<32 - 1 + if retryCount < maxUint32 { + retryCount++ + } + retryWait := time.Duration(retryCount) * cm.cfg.RetryDuration + retryWait = min(retryWait, maxRetryDuration) + log.Debugf("Retrying connection to %v in %v (retries %d)", addr, + retryWait, retryCount) + retryAfter = time.After(retryWait) + continue + } + + // A connection succeeded and was maintained for at least a single + // retry interval. + // + // Clear the retry state. + retryCount = 0 + retryAfter = nil + } + + lastAttempt = time.Now() + var err error + conn, err = cm.dial(ctx, addr, ConnTypeManual, onClose, &connID) if err != nil { - return err + if ctx.Err() != nil { + return + } + + // Retry, potentially after a timeout with backoff. + continue + } + + // Successful connection. + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) } } - return nil } -// listenHandler accepts incoming connections on a given listener. It must be -// run as a goroutine. -func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { - log.Infof("Server listening on %s", listener.Addr()) +// persistentConnsHandler handles launching individual goroutines for persistent +// connections. +func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { + for { + select { + case entry := <-cm.runPersistentChan: + pCtx, cancel := context.WithCancel(ctx) + cm.connMtx.Lock() + entry.cancel = cancel + cm.connMtx.Unlock() + go cm.runPersistent(pCtx, entry.id, entry.addr) + + case <-ctx.Done(): + return + } + } +} + +// targetOutboundHandler attempts to automatically maintain the target number of +// outbound connections configured via [Config.TargetOutbound] when initially +// creating the connection manager. +// +// This MUST be run as a goroutine. +func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { + log.Trace("Starting target outbound handler") + defer log.Trace("Target outbound handler done") + + // failedAttempts tracks the total number of failed outbound connection + // attempts since the last successful connection. It is primarily used to + // detect network outages in order to impose a retry timeout on achieving + // the target number of outbound connections which prevents runaway failed + // connection attempt churn. + // + // Overflow is not checked since it would be virtually impossible to hit + // anywhere max uint64 in practice and even if it ever happened, the only + // consequence would potentially be a few extra retries before it hit the + // max failures again. + var failedAttempts atomic.Uint64 + for ctx.Err() == nil { - conn, err := listener.Accept() - if err != nil { - // Only log the error if not forcibly shutting down. - if ctx.Err() == nil { - log.Errorf("Can't accept connection: %v", err) + // Pause automatic outbound connections for a retry timeout after too + // many failed connection attempts. The network very likely has become + // temporarily unreachable. + if failedAttempts.Load() >= maxFailedAttempts { + log.Debugf("Max failed connection attempts reached [%d] -- "+ + "pausing connections for %v", maxFailedAttempts, + cm.cfg.RetryDuration) + + select { + case <-time.After(cm.cfg.RetryDuration): + case <-cm.quit: + return + case <-ctx.Done(): + return } + } + + // Wait for a permit to make another outbound connection. + if !cm.activeOutboundsSem.Acquire(ctx) { + return + } + + addr, err := cm.cfg.GetNewAddress() + if err != nil { + failedAttempts.Add(1) + log.Debugf("Failed to get address for outbound connection: %v", err) + cm.activeOutboundsSem.Release() continue } - go cm.cfg.OnAccept(conn) - } - log.Tracef("Listener handler done for %s", listener.Addr()) + go func(addr net.Addr) { + onClose := cm.activeOutboundsSem.Release + conn, err := cm.dial(ctx, addr, ConnTypeOutbound, onClose, nil) + if err != nil { + failedAttempts.Add(1) + return + } + + failedAttempts.Store(0) + if cm.cfg.OnConnection != nil { + go cm.cfg.OnConnection(conn) + } + }(addr) + } } // Run starts the connection manager along with its configured listeners and // begins connecting to the network. It blocks until the provided context is -// cancelled. +// canceled. func (cm *ConnManager) Run(ctx context.Context) { log.Trace("Starting connection manager") + defer log.Trace("Connection manager stopped") // Start all the listeners so long as the caller requested them and provided // a callback to be invoked when connections are accepted. @@ -610,27 +1168,56 @@ func (cm *ConnManager) Run(ctx context.Context) { }(listener) } - // Start enough outbound connections to reach the target number when not - // in manual connect mode. + // Start persistent connections handler which starts individual goroutines + // for each persistent connection already added and any newly added ones + // later. + wg.Add(1) + go func() { + cm.persistentConnsHandler(ctx) + wg.Done() + }() + + // Start outbound connection handler to maintain the target number of + // normal outbound connections when not in manual connect mode. if cm.cfg.GetNewAddress != nil { - curConnReqCount := cm.connReqCount.Load() - for i := curConnReqCount; i < uint64(cm.cfg.TargetOutbound); i++ { - go cm.newConnReq(ctx) - } + wg.Add(1) + go func() { + cm.targetOutboundHandler(ctx) + wg.Done() + }() } - // Stop all the listeners and shutdown the connection manager when the - // context is cancelled. There will not be any listeners if listening is - // disabled. + // Shutdown the connection manager when the context is canceled. <-ctx.Done() close(cm.quit) + + // Stop all the listeners. There will not be any listeners if listening is + // disabled. for _, listener := range listeners { // Ignore the error since this is shutdown and there is no way // to recover anyways. _ = listener.Close() } + + // Shutdown persistent conns, cancel pending conns, and close active conns. + cm.connMtx.Lock() + totalIDs := len(cm.persistent) + len(cm.pending) + len(cm.active) + ids := make(map[uint64]struct{}, totalIDs) + for id := range cm.persistent { + ids[id] = struct{}{} + } + for id := range cm.pending { + ids[id] = struct{}{} + } + for id := range cm.active { + ids[id] = struct{}{} + } + cm.connMtx.Unlock() + for id := range ids { + cm.Remove(id) + } + wg.Wait() - log.Trace("Connection manager stopped") } // New returns a new connection manager with the provided configuration. @@ -648,10 +1235,14 @@ func New(cfg *Config) (*ConnManager, error) { cfg.TargetOutbound = defaultTargetOutbound } cm := ConnManager{ - cfg: *cfg, // Copy so caller can't mutate - quit: make(chan struct{}), - pending: make(map[uint64]*ConnReq), - conns: make(map[uint64]*ConnReq, cfg.TargetOutbound), + cfg: *cfg, // Copy so caller can't mutate + quit: make(chan struct{}), + runPersistentChan: make(chan *persistentEntry, MaxPersistent), + activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), + persistent: make(map[uint64]*persistentEntry, MaxPersistent), + pending: make(map[uint64]*pendingConnInfo), + active: make(map[uint64]*Conn, cfg.TargetOutbound), + connIDByAddr: make(map[string]uint64), } return &cm, nil } diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 65900d583..6001ec75a 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "testing" @@ -22,6 +23,28 @@ func init() { maxRetryDuration = 2 * time.Millisecond } +const ( + // connTestReceiveTimeout is the default receive timeout used throughout the + // tests when expecting to receive connections to prevent test hangs. + connTestReceiveTimeout = 10 * time.Millisecond + + // connTestNonReceiveTimeout is the default timeout used throughout the + // tests when expecting that a connection will NOT be received. + connTestNonReceiveTimeout = 20 * time.Millisecond +) + +// mustParseAddrPort parses the provided address into a [*net.TCPAddr] and will +// panic if there is an error. It will only (and must only) be called with +// hard-coded, and therefore known good, addresses. +func mustParseAddrPort(addr string) *net.TCPAddr { + addrPort := netip.MustParseAddrPort(addr) + return &net.TCPAddr{ + IP: addrPort.Addr().AsSlice(), + Port: int(addrPort.Port()), + Zone: addrPort.Addr().Zone(), + } +} + // runConnMgrAsync invokes the Run method on the passed connection manager in a // separate goroutine and returns a cancelable context and wait group the caller // can use to shutdown the connection manager and wait for clean shutdown. @@ -100,38 +123,124 @@ func TestNewConfig(t *testing.T) { } } -// assertConnReqID ensures the provided connection request has the given ID. -func assertConnReqID(t *testing.T, connReq *ConnReq, wantID uint64) { +// assertConnID ensures the provided connection has the given ID. +func assertConnID(t *testing.T, conn *Conn, wantID uint64) { t.Helper() - gotID := connReq.ID() + gotID := conn.ID() if gotID != wantID { t.Fatalf("unexpected ID -- got %v, want %v", gotID, wantID) } } -// assertConnReqState ensures the provided connection request has the given -// state. -func assertConnReqState(t *testing.T, connReq *ConnReq, wantState ConnState) { +// assertConnType ensures the provided connection has the given type. +func assertConnType(t *testing.T, conn *Conn, wantType ConnectionType) { + t.Helper() + + gotType := conn.Type() + if gotType != wantType { + t.Fatalf("unexpected type -- got %v, want %v", gotType, wantType) + } +} + +// pendingAddrConnID returns the connection ID associated with the pending +// connection attempt for the provided address. The second return value will be +// false if no pending attempt is found. +func pendingAddrConnID(cm *ConnManager, addr net.Addr) (uint64, bool) { + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + addrStr := addr.String() + for _, info := range cm.pending { + if info.addr.String() == addrStr { + return info.id, true + } + } + return 0, false +} + +// assertPendingAddr ensures there is a pending connection with the given +// address. +func assertPendingAddr(t *testing.T, cm *ConnManager, addr net.Addr) { t.Helper() - gotState := connReq.State() - if gotState != wantState { - t.Fatalf("unexpected state -- got %v, want %v", gotState, wantState) + if _, ok := pendingAddrConnID(cm, addr); !ok { + t.Fatalf("connection %s is not pending", addr) } } +// assertRemovedPersistent ensures there are no persistent conns with the +// provided address. +func assertRemovedPersistent(t *testing.T, cm *ConnManager, addr net.Addr) { + t.Helper() + + if _, ok := cm.FindPersistentAddrID(addr); ok { + t.Fatalf("found persistent entry for %s", addr) + } +} + +// assertConnReceivedTimeout ensures a connection with the given type is +// received on the provided channel before the given timeout. When given a +// non-zero connection ID, it asserts the received connection has that ID. +func assertConnReceivedTimeout(t *testing.T, ch <-chan *Conn, timeout time.Duration, connID uint64, connType ConnectionType) *Conn { + t.Helper() + + select { + case conn := <-ch: + if connID != 0 { + assertConnID(t, conn, connID) + } + assertConnType(t, conn, connType) + return conn + case <-time.After(timeout): + t.Fatal("connection not received before timeout") + } + return nil +} + +// assertConnReceived ensures a connection with the given type is received on +// the provided channel before the default timeout. When given a non-zero +// connection ID, it asserts the received connection has that ID. +func assertConnReceived(t *testing.T, ch <-chan *Conn, connID uint64, connType ConnectionType) *Conn { + t.Helper() + + return assertConnReceivedTimeout(t, ch, connTestReceiveTimeout, connID, + connType) +} + +// assertNoConnReceivedTimeout ensures no connections are received on the +// provided channel before the given timeout. +func assertNoConnReceivedTimeout(t *testing.T, ch <-chan *Conn, timeout time.Duration) { + t.Helper() + + select { + case conn := <-ch: + conn.Close() + t.Fatalf("got unexpected connection from %v", conn.RemoteAddr()) + case <-time.After(timeout): + // Connection not received as expected. + } +} + +// assertNoConnReceived ensures no connections are received on the provided +// channel before the default timeout. +func assertNoConnReceived(t *testing.T, ch <-chan *Conn) { + t.Helper() + + assertNoConnReceivedTimeout(t, ch, connTestNonReceiveTimeout) +} + // TestConnectMode tests that the connection manager works in the connect mode. // -// In connect mode, automatic connections are disabled, so we test that -// requests using Connect are handled and that no other connections are made. +// In connect mode, automatic connections are disabled, so test that connections +// using [ConnManager.Connect] are handled and that no other connections are +// made. func TestConnectMode(t *testing.T) { - connected := make(chan *ConnReq) + connected := make(chan *Conn) cmgr, err := New(&Config{ TargetOutbound: 2, Dial: mockDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { @@ -139,31 +248,12 @@ func TestConnectMode(t *testing.T) { } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) - - // Ensure that the connection was received. - select { - case gotConnReq := <-connected: - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) - case <-time.After(time.Millisecond * 5): - t.Fatalf("connect mode: connection timeout - %v", cr.Addr) - } - - // Ensure only a single connection was made. - select { - case c := <-connected: - t.Fatalf("connect mode: got unexpected connection - %v", c.Addr) - case <-time.After(time.Millisecond * 5): - } + // Ensure that only a single connection is received. + assertConnReceived(t, connected, 0, ConnTypeManual) + assertNoConnReceived(t, connected) // Ensure clean shutdown of connection manager. shutdown() @@ -175,27 +265,17 @@ func TestConnectMode(t *testing.T) { // ensuring they are the only connections made. func TestTargetOutbound(t *testing.T) { const targetOutbound = 10 - var numConnections atomic.Uint32 - hitTargetConns := make(chan struct{}) - extraConns := make(chan *ConnReq) + var nextAddr atomic.Uint32 + connected := make(chan *Conn) cmgr, err := New(&Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil + addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) + return mustParseAddrPort(addrStr), nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - totalConnections := numConnections.Add(1) - if totalConnections == targetOutbound { - close(hitTargetConns) - return - } - if totalConnections > targetOutbound { - extraConns <- c - } + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { @@ -203,74 +283,61 @@ func TestTargetOutbound(t *testing.T) { } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Wait for the expected number of target outbound conns to be established. - select { - case <-hitTargetConns: - case <-time.After(20 * time.Millisecond): - t.Fatal("did not reach target number of conns before timeout") - } - - // Ensure no additional connections are made. - select { - case c := <-extraConns: - t.Fatalf("target outbound: got unexpected connection - %v", c.Addr) - case <-time.After(time.Millisecond * 5): - break + // Ensure only the expected number of target outbound conns are established + // and no more. + for range targetOutbound { + assertConnReceived(t, connected, 0, ConnTypeOutbound) } + assertNoConnReceived(t, connected) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestRetryPermanent tests that permanent connection requests are retried. -// -// We make a permanent connection request using Connect, disconnect it using -// Disconnect and we wait for it to be connected back. -func TestRetryPermanent(t *testing.T) { - connected := make(chan *ConnReq) - disconnected := make(chan *ConnReq) +// TestRetryPersistent tests that persistent connections are retried. +func TestRetryPersistent(t *testing.T) { + connected := make(chan *Conn) + disconnected := make(chan *Conn) cmgr, err := New(&Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: mockDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, - OnDisconnection: func(c *ConnReq) { - disconnected <- c + OnDisconnection: func(conn *Conn) { + disconnected <- conn }, }) if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, + addr := mustParseAddrPort("127.0.0.1:18555") + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + if !cmgr.IsPersistent(connID) { + t.Fatal("IsPersistent did not reported true for persistent conn") } - go cmgr.Connect(ctx, cr) - gotConnReq := <-connected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) - - cmgr.Disconnect(cr.ID()) - gotConnReq = <-disconnected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnPending) - gotConnReq = <-connected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnEstablished) + // Wait for the first connection, close it, wait for the disconnect, and + // ensure the retry succeeds. + conn := assertConnReceived(t, connected, connID, ConnTypeManual) + conn.Close() + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnReceived(t, connected, connID, ConnTypeManual) - cmgr.Remove(cr.ID()) - gotConnReq = <-disconnected - assertConnReqID(t, gotConnReq, cr.ID()) - assertConnReqState(t, cr, ConnDisconnected) + // Remove the persistent connection, wait for it to disconnect, and ensure + // it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent connection: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) // Ensure clean shutdown of connection manager. shutdown() @@ -290,9 +357,6 @@ func TestMaxRetryDuration(t *testing.T) { } networkUp := make(chan struct{}) - time.AfterFunc(5*time.Millisecond, func() { - close(networkUp) - }) timedDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { select { case <-networkUp: @@ -302,36 +366,34 @@ func TestMaxRetryDuration(t *testing.T) { } } - connected := make(chan *ConnReq) + connected := make(chan *Conn) cmgr, err := New(&Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: timedDialer, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, + connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) } - go cmgr.Connect(ctx, cr) + // retry in 1ms // retry in 2ms - max retry duration reached - // retry in 2ms - timedDialer returns mockDial - select { - case <-connected: - case <-time.After(200 * time.Millisecond): - t.Fatal("max retry duration: connection timeout") - } + // retry in 2ms - timedDialer returns [mockDialer] + const networkUpTimeout = 5 * time.Millisecond + time.AfterFunc(networkUpTimeout, func() { + close(networkUp) + }) + const timeout = connTestReceiveTimeout + networkUpTimeout + assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) // Ensure clean shutdown of connection manager. shutdown() @@ -349,24 +411,24 @@ func TestNetworkFailure(t *testing.T) { connMgrDone := make(chan struct{}) errDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { totalDials := dials.Add(1) - if totalDials >= maxFailedAttempts { + if totalDials > maxFailedAttempts { closeOnce.Do(func() { close(reachedMaxFailedAttempts) }) <-connMgrDone } return nil, errors.New("network down") } + var nextAddr atomic.Uint32 cmgr, err := New(&Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil + addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) + return mustParseAddrPort(addrStr), nil }, - OnConnection: func(c *ConnReq, conn net.Conn) { - t.Fatalf("network failure: got unexpected connection - %v", c.Addr) + OnConnection: func(conn *Conn) { + t.Fatalf("network failure: got unexpected connection - %v", + conn.RemoteAddr()) }, }) if err != nil { @@ -377,7 +439,11 @@ func TestNetworkFailure(t *testing.T) { // Shutdown the connection manager after the max failed attempts is reached // and an additional retry duration has passed and then wait for the // shutdown to complete. - <-reachedMaxFailedAttempts + select { + case <-reachedMaxFailedAttempts: + case <-time.After(retryTimeout * maxFailedAttempts * 3): + t.Fatal("did not reach target number of failed attempts before timeout") + } time.Sleep(retryTimeout) shutdown() close(connMgrDone) @@ -396,7 +462,7 @@ func TestNetworkFailure(t *testing.T) { // TestMultipleFailedConns ensures that the connection manager remains // responsive when there are multiple simultaneous failed connections for -// persistent peers in the retry state. +// persistent conns in the retry state. func TestMultipleFailedConns(t *testing.T) { // Override the max retry duration for this test since it relies on having // multiple connections in the retry state. @@ -424,18 +490,15 @@ func TestMultipleFailedConns(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish several connection requests to localhost IPs. - for i := 0; i < targetFailed; i++ { - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)), - Port: 18555, - }, - Permanent: true, + for i := range targetFailed { + addr := mustParseAddrPort(fmt.Sprintf("127.0.0.%d:18555", i+1)) + _, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected add err: %v", err) } - go cmgr.Connect(ctx, cr) } // Wait for the target number of dials and ensure they happen simultaneously @@ -467,10 +530,6 @@ func TestMultipleFailedConns(t *testing.T) { // TestShutdownFailedConns tests that failed connections are ignored after // connmgr is shutdown. -// -// We have a dialer which sets the stop flag on the conn manager and returns an -// err so that the handler assumes that the conn manager is stopped and ignores -// the failure. func TestShutdownFailedConns(t *testing.T) { var closeOnce sync.Once dialed := make(chan struct{}) @@ -495,30 +554,25 @@ func TestShutdownFailedConns(t *testing.T) { shutdown() }() - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) + // Establish a connection. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) // Ensure clean shutdown of connection manager. wg.Wait() } -// TestRemovePendingConnection tests that it's possible to cancel a pending -// connection, removing its internal state from the connection manager. +// TestRemovePendingConnection ensures that removing a pending outbound +// connection correctly cancels the context used to dial and removes the +// internal state. func TestRemovePendingConnection(t *testing.T) { - // Create a ConnMgr instance with an instance of a dialer that'll never - // succeed. + // Create a conn manager with an instance of a dialer that'll never succeed. dialed := make(chan struct{}) - wait := make(chan struct{}) + canceled := make(chan struct{}) indefiniteDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { close(dialed) - <-wait + <-ctx.Done() + close(canceled) return nil, errors.New("error") } cmgr, err := New(&Config{ @@ -530,121 +584,118 @@ func TestRemovePendingConnection(t *testing.T) { ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - Permanent: true, - } - go cmgr.Connect(ctx, cr) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) - // Wait for the connection manager to attempt to dial the connection request - // and ensure the connection is marked as pending while the dialer is - // blocked. + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. select { case <-dialed: case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertConnReqState(t, cr, ConnPending) + assertPendingAddr(t, cmgr, addr) - // The request launched above will never be able to establish a connection, - // so cancel it _before_ it's able to be completed. - cmgr.Remove(cr.ID()) + // Cancel the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } - // Ensure the connection request is now marked as canceled after a short - // timeout to allow the transition to occur. - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnCanceled) + // Wait for the dialer to signal the context associated with the dial was + // canceled and ensure the internal pending state is removed. + select { + case <-canceled: + case <-time.After(time.Millisecond * 20): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } // Ensure clean shutdown of connection manager. - close(wait) shutdown() wg.Wait() } -// TestCancelIgnoreDelayedConnection tests that a canceled connection request -// will not execute the on connection callback, even if an outstanding retry -// succeeds. +// TestCancelIgnoreDelayedConnection tests that a canceled pending persistent +// connection will not execute the on connection callback, even if a pending +// retry succeeds. func TestCancelIgnoreDelayedConnection(t *testing.T) { const retryTimeout = 10 * time.Millisecond - // Setup a dialer that will continue to return an error until the - // connect chan is signaled. The dial attempt immediately after that - // will succeed in returning a connection. + // Setup a dialer that returns an error on the first attempt and then blocks + // until the connect chan is signaled. The dial attempt immediately after + // that will succeed in returning a connection. + var numAttempts atomic.Uint32 connect := make(chan struct{}) + retried := make(chan struct{}) failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - select { - case <-connect: - return mockDialer(ctx, network, addr) - default: + if numAttempts.Add(1) == 1 { + return nil, errors.New("network down") } - return nil, errors.New("error") + close(retried) + <-connect + + // Override the context to ensure the pending dial succeeds even though + // the passed context will be canceled. + ctx = context.Background() + return mockDialer(ctx, network, addr) } - connected := make(chan *ConnReq) + connected := make(chan *Conn) cmgr, err := New(&Config{ Dial: failingDialer, RetryDuration: retryTimeout, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c + OnConnection: func(conn *Conn) { + connected <- conn }, }) if err != nil { t.Fatalf("New error: %v", err) } - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, + // Establish a persistent connection to a localhost IP. + addr := mustParseAddrPort("127.0.0.1:18555") + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - go cmgr.Connect(ctx, cr) - - // Allow for the first retry timeout to elapse. - time.Sleep(2 * retryTimeout) - // Ensure the status of the connection request is marked as failed, even - // after reattempting to connect. - assertConnReqState(t, cr, ConnFailed) + // Wait for the retry and ensure the connection is pending. + select { + case <-retried: + case <-time.After(20 * time.Millisecond): + t.Fatalf("did not get retry before timeout") + } + assertPendingAddr(t, cmgr, addr) - // Remove the connection, and then immediately allow the next connection - // to succeed. - cmgr.Remove(cr.ID()) + // Remove the connection and then immediately allow the next connection to + // succeed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } close(connect) - // Allow the connection manager to process the removal. - time.Sleep(5 * time.Millisecond) - - // Ensure the status of the connection request is canceled. - assertConnReqState(t, cr, ConnCanceled) - // Finally, the connection manager should not signal the OnConnection // callback, since the request was explicitly canceled. Give a generous - // timeout window to ensure the connection manager's linear backoff is - // allowed to properly elapse. - select { - case <-connected: - t.Fatal("on-connect should not be called for canceled req") - case <-time.After(5 * retryTimeout): - } + // timeout window to ensure the connection manager's backoff is allowed to + // properly elapse. + assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestDialTimeout ensure the Timeout configuration parameter works as intended -// by creating a dialer that blocks for three times the configured dial timeout -// before connecting and ensuring the connection fails as expected. +// TestDialTimeout ensure [Config.Timeout] works as intended by creating a +// dialer that blocks for three times the configured dial timeout before +// connecting and ensuring the connection fails as expected. func TestDialTimeout(t *testing.T) { - // Create a connection manager instance with a dialer that blocks for twice - // the configured dial timeout before connecting. + // Create a connection manager instance with a dialer that blocks for three + // times the configured dial timeout before connecting. const dialTimeout = time.Millisecond * 20 cancelled := make(chan struct{}) timeoutDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -664,35 +715,27 @@ func TestDialTimeout(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - // Establish a connection request to a localhost IP. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - } - go cmgr.Connect(context.Background(), cr) + // Establish a connection to a localhost IP. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) // Wait to receive the signal that the dialer context was cancelled, which - // means the dial timeout was hit, and ensure the connection request is - // marked as failed after a short timeout to allow the transition to occur. + // means the dial timeout was hit. select { case <-cancelled: case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnFailed) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() } -// TestConnectContext ensures the Connect method works as intended when provided -// with a context that times out before a dial attempt succeeds. +// TestConnectContext ensures the [ConnManager.Connect] method works as intended +// when provided with a context that is canceled before a dial attempt succeeds. func TestConnectContext(t *testing.T) { // Create a connection manager instance with a dialer that blocks until its // provided context is canceled. @@ -708,18 +751,17 @@ func TestConnectContext(t *testing.T) { if err != nil { t.Fatalf("New error: %v", err) } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP with a separate context // that can be canceled. - cr := &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, - } - connectCtx, cancelConnect := context.WithCancel(context.Background()) - go cmgr.Connect(connectCtx, cr) + addr := mustParseAddrPort("127.0.0.1:18555") + connectCtx, cancelConnect := context.WithCancel(ctx) + connectErr := make(chan error, 1) + go func() { + _, err := cmgr.Connect(connectCtx, addr) + connectErr <- err + }() // Wait for the connection manager to attempt to dial the connection request // and ensure the connection is marked as pending while the dialer is @@ -729,119 +771,19 @@ func TestConnectContext(t *testing.T) { case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertConnReqState(t, cr, ConnPending) + assertPendingAddr(t, cmgr, addr) - // Cancel the connection context and ensure the connection request is marked - // as failed after a short timeout to allow the transition to occur. + // Cancel the connection context, wait for the error from connect, and + // ensure it is the expected error. cancelConnect() - time.Sleep(10 * time.Millisecond) - assertConnReqState(t, cr, ConnFailed) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() -} - -// TestForEachConnReq tests the connection request iteration logic work as -// expected including for normal, permanent, and pending connections. -func TestForEachConnReq(t *testing.T) { - // Create a connection manager instance with a dialer that recognizes a - // special address to delay on in order to keep it pending. - targetOutbound := uint32(5) - connected := make(chan *ConnReq) - pending := make(chan struct{}) - delayDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - if addr == "127.0.0.1:18557" { - close(pending) - time.Sleep(time.Second) - return nil, errors.New("error") - } - return mockDialer(ctx, network, addr) - } - cmgr, err := New(&Config{ - TargetOutbound: targetOutbound, - Dial: delayDialer, - GetNewAddress: func() (net.Addr, error) { - return &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }, nil - }, - OnConnection: func(c *ConnReq, conn net.Conn) { - connected <- c - }, - }) - if err != nil { - t.Fatalf("New error: %v", err) - } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - - // Wait for the expected number of target outbound conns to be established. - allConnected := make(chan struct{}) - go func() { - for i := uint32(0); i < targetOutbound; i++ { - <-connected - } - close(allConnected) - }() - select { - case <-allConnected: - case <-time.After(time.Millisecond * 5 * time.Duration(targetOutbound)): - t.Fatal("timeout waiting for connections") - } - - // Create a permanent connection. - cr := &ConnReq{ - Permanent: true, - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - }, - } - go cmgr.Connect(context.Background(), cr) - select { - case <-connected: - case <-time.After(time.Millisecond * 5): - t.Fatal("timeout waiting for permanent connection") - } - - // Create a connection that triggers the mock dialer to keep it pending. - cr = &ConnReq{ - Addr: &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18557, - }, - } - go cmgr.Connect(context.Background(), cr) select { - case <-pending: - case <-time.After(time.Millisecond * 5): - t.Fatal("timeout waiting for pending connection") - } - - // Ensure the expected number of each type of connection exists. - var numConnected, numPermanent, numPending uint32 - _ = cmgr.ForEachConnReq(func(cr *ConnReq) error { - numConnected++ - if cr.State() == ConnPending { - numPending++ - } - if cr.Permanent { - numPermanent++ + case err := <-connectErr: + if !errors.Is(err, context.Canceled) { + t.Fatalf("unexpected connect err: got %v, want %v", err, + context.Canceled) } - return nil - }) - if numConnected != targetOutbound+2 { - t.Fatalf("unexpected number of iterated conn reqs -- got %d, want %d", - numConnected, targetOutbound+2) - } - if numPermanent != 1 { - t.Fatalf("unexpected number of permanent conn reqs -- got %d, want %d", - numPermanent, 1) - } - if numPending != 1 { - t.Fatalf("unexpected number of pending conn reqs -- got %d, want %d", - numPending, 1) + case <-time.After(10 * time.Millisecond): + t.Fatal("timeout waiting for dial cancellation") } // Ensure clean shutdown of connection manager. @@ -888,14 +830,11 @@ func (m *mockListener) Addr() net.Addr { // address. It will cause the Accept function to return a mock connection // configured with the provided remote address and the local address for the // mock listener. -func (m *mockListener) Connect(ip string, port int) { +func (m *mockListener) Connect(addr net.Addr) { m.provideConn <- &mockConn{ laddr: m.localAddr, lnet: "tcp", - rAddr: &net.TCPAddr{ - IP: net.ParseIP(ip), - Port: port, - }, + rAddr: addr, } } @@ -913,13 +852,13 @@ func newMockListener(localAddr string) *mockListener { func TestListeners(t *testing.T) { // Setup a connection manager with a couple of mock listeners that // notify a channel when they receive mock connections. - receivedConns := make(chan net.Conn) - listener1 := newMockListener("127.0.0.1:8333") - listener2 := newMockListener("127.0.0.1:9333") + receivedConns := make(chan *Conn) + listener1 := newMockListener("127.0.0.1:9108") + listener2 := newMockListener("127.0.0.1:9208") listeners := []net.Listener{listener1, listener2} cmgr, err := New(&Config{ Listeners: listeners, - OnAccept: func(conn net.Conn) { + OnAccept: func(conn *Conn) { receivedConns <- conn }, Dial: mockDialer, @@ -933,29 +872,15 @@ func TestListeners(t *testing.T) { go func() { for i, listener := range listeners { l := listener.(*mockListener) - l.Connect("127.0.0.1", 10000+i*2) - l.Connect("127.0.0.1", 10000+i*2+1) + l.Connect(mustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", 10000+i*2))) + l.Connect(mustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", 10000+i*2+1))) } }() - // Tally the receive connections to ensure the expected number are - // received. Also, fail the test after a timeout so it will not hang - // forever should the test not work. + // Ensure the expected number of inbound connections are received. expectedNumConns := len(listeners) * 2 - var numConns int -out: - for { - select { - case <-receivedConns: - numConns++ - if numConns == expectedNumConns { - break out - } - - case <-time.After(time.Millisecond * 50): - t.Fatalf("Timeout waiting for %d expected connections", - expectedNumConns) - } + for range expectedNumConns { + assertConnReceived(t, receivedConns, 0, ConnTypeInbound) } // Ensure clean shutdown of connection manager. diff --git a/internal/connmgr/conntype_test.go b/internal/connmgr/conntype_test.go new file mode 100644 index 000000000..f4d3f69e1 --- /dev/null +++ b/internal/connmgr/conntype_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "testing" +) + +// TestConnectionTypeStringer tests the stringized output for connection types. +func TestConnectionTypeStringer(t *testing.T) { + tests := []struct { + in ConnectionType + want string + }{ + {ConnTypeInbound, "inbound"}, + {ConnTypeOutbound, "outbound"}, + {ConnTypeManual, "manual"}, + {0xff, "Unknown ConnectionType (255)"}, + } + + // Detect additional defines that don't have the stringer added. + if len(tests)-1 != int(numConnTypes) { + t.Fatal("It appears a connection type was added without adding an " + + "associated stringer test") + } + + for i, test := range tests { + if got := test.in.String(); got != test.want { + t.Errorf("String #%d: got: %s, want: %s", i, got, test.want) + continue + } + } +} diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 932a13f28..9c87dad6e 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -14,10 +14,34 @@ const ( // the configuration. ErrDialNil = ErrorKind("ErrDialNil") + // ErrAlreadyPending indicates an attempt to connect to an address that + // already has a pending connection attempt. + ErrAlreadyPending = ErrorKind("ErrAlreadyPending") + + // ErrAlreadyConnected indicates an attempt to connect to an address that + // already has an established connection. + ErrAlreadyConnected = ErrorKind("ErrAlreadyConnected") + + // ErrMaxPersistent indicates an attempt to add more than the maximum + // allowed number of persistent connections. + ErrMaxPersistent = ErrorKind("ErrMaxPersistent") + + // ErrDuplicatePersistent indicates an attempt to add a persistent + // connection to an address that already exists. + ErrDuplicatePersistent = ErrorKind("ErrDuplicatePersistent") + // ErrNotFound indicates a specified connection ID or address is unknown to // the connection manager. ErrNotFound = ErrorKind("ErrNotFound") + // ErrUnsupportedAddr indicates an address is either an unsupported type or + // an unrecognized type due to being malformed. + ErrUnsupportedAddr = ErrorKind("ErrUnsupportedAddr") + + // ErrShutdown indicates the connection manager is either in the process of + // shutting down or has already been shutdown. + ErrShutdown = ErrorKind("ErrShutdown") + // ErrTorInvalidAddressResponse indicates an invalid address was // returned by the Tor DNS resolver. ErrTorInvalidAddressResponse = ErrorKind("ErrTorInvalidAddressResponse") diff --git a/internal/connmgr/error_test.go b/internal/connmgr/error_test.go index 1e177b97e..d4e5d2262 100644 --- a/internal/connmgr/error_test.go +++ b/internal/connmgr/error_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2021 The Decred developers +// Copyright (c) 2020-2026 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -17,7 +17,13 @@ func TestErrorKindStringer(t *testing.T) { want string }{ {ErrDialNil, "ErrDialNil"}, + {ErrAlreadyPending, "ErrAlreadyPending"}, + {ErrAlreadyConnected, "ErrAlreadyConnected"}, + {ErrMaxPersistent, "ErrMaxPersistent"}, + {ErrDuplicatePersistent, "ErrDuplicatePersistent"}, {ErrNotFound, "ErrNotFound"}, + {ErrUnsupportedAddr, "ErrUnsupportedAddr"}, + {ErrShutdown, "ErrShutdown"}, {ErrTorInvalidAddressResponse, "ErrTorInvalidAddressResponse"}, {ErrTorInvalidProxyResponse, "ErrTorInvalidProxyResponse"}, {ErrTorUnrecognizedAuthMethod, "ErrTorUnrecognizedAuthMethod"}, diff --git a/internal/rpcserver/interface.go b/internal/rpcserver/interface.go index 97b88fde9..c87134f49 100644 --- a/internal/rpcserver/interface.go +++ b/internal/rpcserver/interface.go @@ -114,7 +114,7 @@ type ConnManager interface { // permanent flag indicates whether or not to make the peer persistent // and reconnect if the connection is lost. Attempting to connect to an // already existing peer will return an error. - Connect(addr string, permanent bool) error + Connect(ctx context.Context, addr string, permanent bool) error // RemoveByID removes the peer associated with the provided id from the // list of persistent peers. Attempting to remove an id that does not diff --git a/internal/rpcserver/rpcserver.go b/internal/rpcserver/rpcserver.go index f4907bdd8..175ec9831 100644 --- a/internal/rpcserver/rpcserver.go +++ b/internal/rpcserver/rpcserver.go @@ -598,7 +598,7 @@ func newWorkState() *workState { } // handleAddNode handles addnode commands. -func handleAddNode(_ context.Context, s *Server, cmd any) (any, error) { +func handleAddNode(ctx context.Context, s *Server, cmd any) (any, error) { c := cmd.(*types.AddNodeCmd) addr := normalizeAddress(c.Addr, s.cfg.ChainParams.DefaultPort) @@ -606,25 +606,47 @@ func handleAddNode(_ context.Context, s *Server, cmd any) (any, error) { var err error switch c.SubCmd { case "add": - err = connMgr.Connect(addr, true) + err = connMgr.Connect(ctx, addr, true) case "remove": err = connMgr.RemoveByAddr(addr) case "onetry": - err = connMgr.Connect(addr, false) + err = connMgr.Connect(ctx, addr, false) default: return nil, rpcInvalidError("Invalid subcommand for addnode") } if err != nil { - return nil, rpcInvalidError("%v: %v", c.SubCmd, err) + switch { + // Connecting involves child contexts, so there is no guarantee that + // context errors returned from Connect are the result of the parent + // context. + // + // Check the parent context first to determine if the failure is the + // result of the RPC server (e.g. RPC connection closed, server + // shutdown, etc). + // + // Otherwise, context errors refer to the actual connection attempt. + case ctx.Err() != nil: + return nil, rpcConnectionClosedError() + + case errors.Is(err, context.Canceled): + return nil, rpcCancelError("%v: connection attempt to %v canceled", + c.SubCmd, addr) + + case errors.Is(err, context.DeadlineExceeded): + return nil, rpcCancelError("%v: timeout connecting to %v", c.SubCmd, + addr) + } + + prefix := fmt.Sprintf("%v: failed operation on %v", c.SubCmd, addr) + return nil, rpcInternalErr(err, prefix) } - // no data returned unless an error. return nil, nil } // handleNode handles node commands. -func handleNode(_ context.Context, s *Server, cmd any) (any, error) { +func handleNode(ctx context.Context, s *Server, cmd any) (any, error) { c := cmd.(*types.NodeCmd) connMgr := s.cfg.ConnMgr @@ -646,13 +668,16 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { addr = normalizeAddress(c.Target, params.DefaultPort) err = connMgr.DisconnectByAddr(addr) } else { - return nil, rpcInvalidError("%v: Invalid "+ - "address or node ID", c.SubCmd) + return nil, rpcInvalidError("%v: invalid address or node ID", + c.SubCmd) } } if err != nil && peerExists(connMgr, addr, int32(nodeID)) { - return nil, rpcMiscError("can't disconnect a permanent peer, " + - "use remove") + return nil, rpcMiscError("can't disconnect a permanent peer, use " + + "remove") + } + if err != nil { + return nil, rpcInvalidError("%v: %v", c.SubCmd, err) } case "remove": @@ -667,13 +692,16 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { addr = normalizeAddress(c.Target, params.DefaultPort) err = connMgr.RemoveByAddr(addr) } else { - return nil, rpcInvalidError("%v: invalid "+ - "address or node ID", c.SubCmd) + return nil, rpcInvalidError("%v: invalid address or node ID", + c.SubCmd) } } if err != nil && peerExists(connMgr, addr, int32(nodeID)) { - return nil, rpcMiscError("can't remove a temporary peer, " + - "use disconnect") + return nil, rpcMiscError("can't remove a temporary peer, use " + + "disconnect") + } + if err != nil { + return nil, rpcInvalidError("%v: %v", c.SubCmd, err) } case "connect": @@ -687,20 +715,42 @@ func handleNode(_ context.Context, s *Server, cmd any) (any, error) { switch subCmd { case "perm", "temp": - err = connMgr.Connect(addr, subCmd == "perm") + err = connMgr.Connect(ctx, addr, subCmd == "perm") default: - return nil, rpcInvalidError("%v: invalid subcommand "+ - "for node connect", subCmd) + return nil, rpcInvalidError("%v: invalid subcommand for node "+ + "connect", subCmd) } + if err != nil { + // Connecting involves child contexts, so there is no guarantee that + // context errors returned from Connect are the result of the parent + // context. + // + // Check the parent context first to determine if the failure is the + // result of the RPC server (e.g. RPC connection closed, server + // shutdown, etc). + // + // Otherwise, context errors refer to the actual connection attempt. + switch { + case ctx.Err() != nil: + return nil, rpcConnectionClosedError() + + case errors.Is(err, context.Canceled): + return nil, rpcCancelError("%v: connection attempt to %v "+ + "canceled", c.SubCmd, addr) + + case errors.Is(err, context.DeadlineExceeded): + return nil, rpcCancelError("%v: timeout connecting to %v", + c.SubCmd, addr) + } + + prefix := fmt.Sprintf("%v: failed operation on %v", c.SubCmd, addr) + return nil, rpcInternalErr(err, prefix) + } + default: return nil, rpcInvalidError("%v: invalid subcommand for node", c.SubCmd) } - if err != nil { - return nil, rpcInvalidError("%v: %v", c.SubCmd, err) - } - - // no data returned unless an error. return nil, nil } diff --git a/internal/rpcserver/rpcserverhandlers_test.go b/internal/rpcserver/rpcserverhandlers_test.go index ac6d8fd3b..25d1a070c 100644 --- a/internal/rpcserver/rpcserverhandlers_test.go +++ b/internal/rpcserver/rpcserverhandlers_test.go @@ -856,7 +856,7 @@ type testConnManager struct { // Connect provides a mock implementation for adding the provided address as a // new outbound peer. -func (c *testConnManager) Connect(addr string, permanent bool) error { +func (c *testConnManager) Connect(ctx context.Context, addr string, permanent bool) error { return c.connectErr } @@ -1919,7 +1919,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: 'remove' subcommand error", handler: handleAddNode, @@ -1933,7 +1933,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: 'onetry' subcommand error", handler: handleAddNode, @@ -1947,7 +1947,7 @@ func TestHandleAddNode(t *testing.T) { return connManager }(), wantErr: true, - errCode: dcrjson.ErrRPCInvalidParameter, + errCode: dcrjson.ErrRPCInternal.Code, }, { name: "handleAddNode: invalid subcommand", handler: handleAddNode, diff --git a/rpcadaptors.go b/rpcadaptors.go index 8dd072a94..a6826e047 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -15,7 +15,6 @@ import ( "github.com/decred/dcrd/chaincfg/v3" "github.com/decred/dcrd/dcrutil/v4" "github.com/decred/dcrd/internal/blockchain" - "github.com/decred/dcrd/internal/connmgr" "github.com/decred/dcrd/internal/mempool" "github.com/decred/dcrd/internal/mining" "github.com/decred/dcrd/internal/mining/cpuminer" @@ -125,29 +124,7 @@ var _ rpcserver.ConnManager = (*rpcConnManager)(nil) // // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. -func (cm *rpcConnManager) Connect(addr string, permanent bool) error { - // Prevent duplicate connections to the same peer. - connManager := cm.server.connManager - err := connManager.ForEachConnReq(func(c *connmgr.ConnReq) error { - if c.Addr != nil && c.Addr.String() == addr { - if c.Permanent { - return errors.New("peer exists as a permanent peer") - } - - switch c.State() { - case connmgr.ConnPending: - return errors.New("peer pending connection") - case connmgr.ConnEstablished: - return errors.New("peer already connected") - - } - } - return nil - }) - if err != nil { - return err - } - +func (cm *rpcConnManager) Connect(ctx context.Context, addr string, permanent bool) error { netAddr, err := addrStringToNetAddr(addr) if err != nil { return err @@ -161,40 +138,44 @@ func (cm *rpcConnManager) Connect(addr string, permanent bool) error { return errors.New("max peers reached") } - go connManager.Connect(context.Background(), &connmgr.ConnReq{ - Addr: netAddr, - Permanent: permanent, - }) - return nil + // Attempt to add a persistent peer when requested. + connManager := cm.server.connManager + if permanent { + _, err := connManager.AddPersistent(netAddr) + return err + } + + // Attempt to connect to the address. + _, err = connManager.Connect(ctx, netAddr) + return err } -// removeNode removes any peers that the provided compare function return true +// errPeerNotFound is returned by the RPC conn manager when no matching peer for +// a given address or ID is found. +var errPeerNotFound = errors.New("peer not found") + +// removeNode removes any peer that the provided compare function return true // for from the list of persistent peers. // -// An error will be returned if no matching peers are found (aka the compare +// An error will be returned if no matching peer is found (aka the compare // function returns false for all peers). func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { state := &cm.server.peerState + var found *serverPeer state.Lock() - found := disconnectPeer(state.persistentPeers, cmp, func(sp *serverPeer) { - // Update the group counts since the peer will be removed from the - // persistent peers just after this func returns. - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - - connReq := sp.connReq.Load() - peerLog.Debugf("Removing persistent peer %s (reqid %d)", sp.remoteAddr, - connReq.ID()) - - // Mark the peer's connReq as nil to prevent it from scheduling a - // re-connect attempt. - sp.connReq.Store(nil) - cm.server.connManager.Remove(connReq.ID()) - }) + for _, peer := range state.persistentPeers { + if cmp(peer) { + found = peer + break + } + } state.Unlock() - - if !found { - return errors.New("peer not found") + if found == nil { + return errPeerNotFound } + + peerLog.Debugf("Removing persistent peer %s", found.remoteAddr) + cm.server.connManager.Remove(found.conn.ID()) return nil } @@ -203,10 +184,20 @@ func (cm *rpcConnManager) removeNode(cmp func(*serverPeer) bool) error { // an error. // // This function is safe for concurrent access and is part of the -// rpcserver.ConnManager interface implementation. +// [rpcserver.ConnManager] interface implementation. func (cm *rpcConnManager) RemoveByID(id int32) error { + // Attempt to remove the peer by ID first. When the ID does not correspond + // to an established persistent peer, fall back to treating the ID as a + // connection ID and remove it when it is for a persistent connection. + connManager := cm.server.connManager cmp := func(sp *serverPeer) bool { return sp.ID() == id } - return cm.removeNode(cmp) + err := cm.removeNode(cmp) + if errors.Is(err, errPeerNotFound) && connManager.IsPersistent(uint64(id)) { + if rErr := connManager.Remove(uint64(id)); rErr == nil { + return nil + } + } + return err } // RemoveByAddr removes the peer associated with the provided address from the @@ -214,57 +205,67 @@ func (cm *rpcConnManager) RemoveByID(id int32) error { // exist will return an error. // // This function is safe for concurrent access and is part of the -// rpcserver.ConnManager interface implementation. +// [rpcserver.ConnManager] interface implementation. func (cm *rpcConnManager) RemoveByAddr(addr string) error { + // Attempt to remove the peer by address first. When the address does not + // correspond to an established persistent peer, fall back to searching the + // connection manager directly for a matching persistent connection entry + // and remove it when found. cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } err := cm.removeNode(cmp) - if err != nil { - netAddr, err := addrStringToNetAddr(addr) - if err != nil { - return err + if errors.Is(err, errPeerNotFound) { + netAddr := simpleAddr{"tcp", addr} + if id, ok := cm.server.connManager.FindPersistentAddrID(netAddr); ok { + cm.server.connManager.Remove(id) + return nil } - return cm.server.connManager.CancelPending(netAddr) } - return nil + return err } -// disconnectNode disconnects any peers that the provided compare function +// disconnectNode disconnects any peer that the provided compare function // returns true for. It applies to both inbound and outbound peers. // -// An error will be returned if no matching peers are found (aka the compare +// An error will be returned if no matching peer is found (aka the compare // function returns false for all peers). // // This function is safe for concurrent access. func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error { state := &cm.server.peerState - defer state.Unlock() + state.Lock() + defer state.Unlock() - // Check inbound peers. No callback is passed since there are no additional - // actions on disconnect for inbound peers. - found := disconnectPeer(state.inboundPeers, cmp, nil) - if found { - return nil - } + // The code below uses the fact that the connection manager prevents + // connections with duplicate addresses to limit the search to a single + // match. - // Check outbound peers in a loop to ensure all outbound connections to the - // same ip:port are disconnected when there are multiple. - var numFound uint32 - for ; ; numFound++ { - found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) { - // Update the group counts since the peer will be removed from the - // persistent peers just after this func returns. - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - }) - if !found { + // Check inbound peers. + var inbound *serverPeer + for _, peer := range state.inboundPeers { + if cmp(peer) { + inbound = peer break } } + if inbound != nil { + inbound.Disconnect() + return nil + } - if numFound == 0 { - return errors.New("peer not found") + // Check outbound peers. + var outbound *serverPeer + for _, peer := range state.outboundPeers { + if cmp(peer) { + outbound = peer + } } - return nil + if outbound != nil { + outbound.Disconnect() + return nil + } + + return errPeerNotFound } // DisconnectByID disconnects the peer associated with the provided id. This diff --git a/server.go b/server.go index 8914090da..f61afaf33 100644 --- a/server.go +++ b/server.go @@ -442,6 +442,7 @@ type serverPeer struct { // The service flags are updated in the address manager directly once the // peer reports them. The service flags on this instance are never used. server *server + conn *connmgr.Conn remoteAddr *addrmgr.NetAddress persistent bool isWhitelisted bool @@ -455,7 +456,6 @@ type serverPeer struct { // otherwise modified during operation and thus need to consider whether or // not they need to be protected for concurrent access. - connReq atomic.Pointer[connmgr.ConnReq] continueHash atomic.Pointer[chainhash.Hash] disableRelayTx atomic.Bool knownAddresses *apbf.Filter @@ -507,11 +507,12 @@ type serverPeer struct { // newServerPeer returns a new serverPeer instance. The peer needs to be set by // the caller. -func newServerPeer(s *server, remoteAddr *addrmgr.NetAddress, isPersistent bool) *serverPeer { +func newServerPeer(s *server, conn *connmgr.Conn, remoteAddr *addrmgr.NetAddress) *serverPeer { return &serverPeer{ server: s, + conn: conn, remoteAddr: remoteAddr, - persistent: isPersistent, + persistent: s.connManager.IsPersistent(conn.ID()), knownAddresses: apbf.NewFilter(maxKnownAddrsPerPeer, knownAddrsFPRate), quit: make(chan struct{}), getDataQueue: make(chan []*wire.InvVect, maxConcurrentGetDataReqs), @@ -2177,57 +2178,6 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { }) } -// disconnectPeer attempts to drop the connection of a targeted peer in the -// passed peer list. Targets are identified via usage of the passed -// `compareFunc`, which should return `true` if the passed peer is the target -// peer. This function returns true on success and false if the peer is unable -// to be located. If the peer is found, and the passed callback: `whenFound' -// isn't nil, we call it with the peer as the argument before it is removed -// from the peerList, and is disconnected from the server. -func disconnectPeer(peerList map[int32]*serverPeer, compareFunc func(*serverPeer) bool, whenFound func(*serverPeer)) bool { - for addr, peer := range peerList { - if compareFunc(peer) { - if whenFound != nil { - whenFound(peer) - } - - // This is ok because we are not continuing - // to iterate so won't corrupt the loop. - delete(peerList, addr) - peer.Disconnect() - return true - } - } - return false -} - -// connToNetAddr parses and returns an address manager network address from the -// remote address associated with the given connection. -// -// This function is safe for concurrent access. -func connToNetAddr(conn net.Conn) (*addrmgr.NetAddress, error) { - addrStr := conn.RemoteAddr().String() - host, portStr, err := net.SplitHostPort(addrStr) - if err != nil { - return nil, err - } - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, err - } - - addrType, addrBytes := addrmgr.EncodeHost(host) - if addrType == addrmgr.UnknownAddressType { - return nil, fmt.Errorf("unable to determine address type: %v", addrStr) - } - - // Since the host type has been successfully recognized and encoded, - // there is no need to perform a DNS lookup. - now := time.Unix(time.Now().Unix(), 0) - return addrmgr.NewNetAddressFromParams(addrType, addrBytes, uint16(port), - now, 0) -} - // handleBannedConn closes the provided connection if the remote address // associated with it is banned or the address can't be properly parsed. It // returns true when the connection is closed. @@ -2323,11 +2273,11 @@ func newPeerConfig(sp *serverPeer) *peer.Config { // instance, associates it with the connection, runs the peer (which starts all // additional server peer processing goroutines) and blocks until the peer // disconnects. -func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { - remoteNetAddr, err := connToNetAddr(conn) - if err != nil { - srvrLog.Debugf("Unable to create inbound peer for address %s: %v", - conn.RemoteAddr(), err) +func (s *server) inboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { + remoteNetAddr, ok := conn.RemoteAddr().(*addrmgr.NetAddress) + if !ok { + srvrLog.Warnf("remote address for connection is incorrect type %T", + conn.RemoteAddr()) conn.Close() return } @@ -2337,9 +2287,9 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { return } - sp := newServerPeer(s, remoteNetAddr, false) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) + sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) + sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for inbound peer %s: %v", remoteNetAddr, err) @@ -2355,31 +2305,28 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { // peer instance, associates it with the relevant state such as the connection // request instance and the connection itself, and start all additional server // peer processing goroutines. -func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq, conn net.Conn) { - remoteNetAddr, err := connToNetAddr(conn) - if err != nil { - srvrLog.Debugf("Unable to create outbound peer for address %s: %v", - conn.RemoteAddr(), err) +func (s *server) outboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { + remoteNetAddr, ok := conn.RemoteAddr().(*addrmgr.NetAddress) + if !ok { + srvrLog.Warnf("remote address for connection is incorrect type %T", + conn.RemoteAddr()) conn.Close() - s.connManager.Disconnect(c.ID()) + return } // Disconnect banned connections. Ideally we would never connect to a // banned peer, but the connection manager is currently unaware of banned // addresses, so this is needed. if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected { - s.connManager.Disconnect(c.ID()) return } - sp := newServerPeer(s, remoteNetAddr, c.Permanent) - p := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr, conn) - sp.Peer = p - sp.connReq.Store(c) + sp := newServerPeer(s, conn, remoteNetAddr) + sp.Peer = peer.NewOutboundPeer(newPeerConfig(sp), conn.RemoteAddr(), conn) sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { - srvrLog.Debugf("Failed handshake for outbound peer %s: %v", c.Addr, err) - s.connManager.Disconnect(c.ID()) + srvrLog.Debugf("Failed handshake for outbound peer %s: %v", + conn.RemoteAddr(), err) return } sp.syncMgrPeer = netsync.NewPeer(sp.Peer) @@ -2818,21 +2765,12 @@ func (s *server) DonePeer(sp *serverPeer) { if _, ok := list[sp.ID()]; ok { if !sp.Inbound() { state.outboundGroups[sp.remoteAddr.GroupKey()]-- - connReq := sp.connReq.Load() - if connReq != nil { - s.connManager.Disconnect(connReq.ID()) - } } delete(list, sp.ID()) srvrLog.Debugf("Removed peer %s", sp) return } - connReq := sp.connReq.Load() - if connReq != nil { - s.connManager.Disconnect(connReq.ID()) - } - // Update the address manager with the last seen time. This is skipped when // running on the simulation and regression test networks since they are // only intended to connect to specified peers and actively avoid @@ -4418,15 +4356,15 @@ func newServer(ctx context.Context, profiler *profileServer, } cmgr, err := connmgr.New(&connmgr.Config{ Listeners: listeners, - OnAccept: func(conn net.Conn) { + OnAccept: func(conn *connmgr.Conn) { s.inboundPeerConnected(ctx, conn) }, RetryDuration: connectionRetryInterval, TargetOutbound: s.targetOutbound, Dial: s.attemptDcrdDial, DialTimeout: cfg.DialTimeout, - OnConnection: func(c *connmgr.ConnReq, conn net.Conn) { - s.outboundPeerConnected(ctx, c, conn) + OnConnection: func(conn *connmgr.Conn) { + s.outboundPeerConnected(ctx, conn) }, GetNewAddress: newAddressFunc, }) @@ -4435,22 +4373,21 @@ func newServer(ctx context.Context, profiler *profileServer, } s.connManager = cmgr - // Start up persistent peers. - permanentPeers := cfg.ConnectPeers - if len(permanentPeers) == 0 { - permanentPeers = cfg.AddPeers + // Add persistent peers. + persistentPeers := cfg.ConnectPeers + if len(persistentPeers) == 0 { + persistentPeers = cfg.AddPeers } - for _, addr := range permanentPeers { + for _, addr := range persistentPeers { tcpAddr, err := addrStringToNetAddr(addr) if err != nil { return nil, err } - go s.connManager.Connect(ctx, - &connmgr.ConnReq{ - Addr: tcpAddr, - Permanent: true, - }) + _, err = s.connManager.AddPersistent(tcpAddr) + if err != nil { + return nil, err + } } if !cfg.DisableRPC { @@ -4636,14 +4573,14 @@ func addrStringToNetAddr(addr string) (net.Addr, error) { return nil, fmt.Errorf("no addresses found for %s", host) } - port, err := strconv.Atoi(strPort) + port, err := strconv.ParseUint(strPort, 10, 16) if err != nil { return nil, err } return &net.TCPAddr{ IP: ips[0], - Port: port, + Port: int(port), }, nil } From 9ad909d60d27e833f5582c641e528aa484e2dd8e Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Wed, 20 May 2026 18:45:53 -0500 Subject: [PATCH 04/51] connmgr: Make max retry duration a field. The max retry duration is currently an unexported global variable that the tests override at init time. At least one of the tests also additionally overrides it for that specified test too. While this works, it is somewhat brittle and prevents the tests from being run in parallel. This improves the situation by making the max retry duration a field on the connection manager instead of a global variable and adding a test helper for creating a new connection manager that overrides it by default. Then any tests that need a different value can simply override it on their local instance. It also makes the tests parallel since they can no longer clobber one another. --- internal/connmgr/connmanager.go | 21 +++-- internal/connmgr/connmanager_test.go | 130 +++++++++++++-------------- 2 files changed, 76 insertions(+), 75 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 83ddf10c5..e4e75b06d 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -25,14 +25,6 @@ const ( MaxPersistent = 8 ) -var ( - // maxRetryDuration is the maximum duration a persistent connection retry - // backoff is allowed to grow to. This is necessary since the retry logic - // uses a backoff mechanism which increases the interval base times the - // number of retries that have been done. - maxRetryDuration = time.Minute * 5 -) - const ( // maxFailedAttempts is the maximum number of successive failed connection // attempts after which network failure is assumed and new connections will @@ -43,6 +35,12 @@ const ( // persistent connections. defaultRetryDuration = time.Second * 5 + // defaultMaxRetryDuration is the default maximum duration a persistent + // connection retry backoff is allowed to grow to. This is necessary since + // the retry logic uses a backoff mechanism which increases the interval + // base times the number of retries that have been done. + defaultMaxRetryDuration = time.Minute * 5 + // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 @@ -274,6 +272,10 @@ type ConnManager struct { // creating time and treated as immutable after that. cfg Config + // maxRetryDuration is the maximum duration a persistent connection retry + // backoff is allowed to grow to. + maxRetryDuration time.Duration + // runPersistentChan is used to signal the persistent connections handler to // launch a goroutine that attempts to always maintain an established // connection with a given address. @@ -1026,7 +1028,7 @@ func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr ne retryCount++ } retryWait := time.Duration(retryCount) * cm.cfg.RetryDuration - retryWait = min(retryWait, maxRetryDuration) + retryWait = min(retryWait, cm.maxRetryDuration) log.Debugf("Retrying connection to %v in %v (retries %d)", addr, retryWait, retryCount) retryAfter = time.After(retryWait) @@ -1237,6 +1239,7 @@ func New(cfg *Config) (*ConnManager, error) { cm := ConnManager{ cfg: *cfg, // Copy so caller can't mutate quit: make(chan struct{}), + maxRetryDuration: defaultMaxRetryDuration, runPersistentChan: make(chan *persistentEntry, MaxPersistent), activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), persistent: make(map[uint64]*persistentEntry, MaxPersistent), diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 6001ec75a..fda415cd2 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -18,12 +18,11 @@ import ( "time" ) -func init() { - // Override the max retry duration when running tests. - maxRetryDuration = 2 * time.Millisecond -} - const ( + // defaultTestMaxRetryDuration is the default max duration a connection + // retry backoff is allowed to grow to when running tests. + defaultTestMaxRetryDuration = 2 * time.Millisecond + // connTestReceiveTimeout is the default receive timeout used throughout the // tests when expecting to receive connections to prevent test hangs. connTestReceiveTimeout = 10 * time.Millisecond @@ -109,18 +108,32 @@ func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) { return c, ctx.Err() } +// newTestConnManager returns a new connection manager with the provided +// configuration and some timeout tweaks so that it is suitable for use in the +// tests. +func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { + t.Helper() + + cmgr, err := New(cfg) + if err != nil { + t.Fatalf("New: unexpected error: %v", err) + } + cmgr.maxRetryDuration = defaultTestMaxRetryDuration + return cmgr +} + // TestNewConfig tests that new ConnManager config is validated as expected. func TestNewConfig(t *testing.T) { + t.Parallel() + _, err := New(&Config{}) if err == nil { t.Fatal("New expected error: 'Dial can't be nil', got nil") } - _, err = New(&Config{ + + newTestConnManager(t, &Config{ Dial: mockDialer, }) - if err != nil { - t.Fatalf("New unexpected error: %v", err) - } } // assertConnID ensures the provided connection has the given ID. @@ -235,17 +248,16 @@ func assertNoConnReceived(t *testing.T, ch <-chan *Conn) { // using [ConnManager.Connect] are handled and that no other connections are // made. func TestConnectMode(t *testing.T) { + t.Parallel() + connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ TargetOutbound: 2, Dial: mockDialer, OnConnection: func(conn *Conn) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) addr := mustParseAddrPort("127.0.0.1:18555") @@ -264,10 +276,12 @@ func TestConnectMode(t *testing.T) { // configuration option by waiting until all connections are established and // ensuring they are the only connections made. func TestTargetOutbound(t *testing.T) { + t.Parallel() + const targetOutbound = 10 var nextAddr atomic.Uint32 connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { @@ -278,9 +292,6 @@ func TestTargetOutbound(t *testing.T) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Ensure only the expected number of target outbound conns are established @@ -297,9 +308,11 @@ func TestTargetOutbound(t *testing.T) { // TestRetryPersistent tests that persistent connections are retried. func TestRetryPersistent(t *testing.T) { + t.Parallel() + connected := make(chan *Conn) disconnected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: mockDialer, @@ -310,9 +323,6 @@ func TestRetryPersistent(t *testing.T) { disconnected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) addr := mustParseAddrPort("127.0.0.1:18555") @@ -349,11 +359,13 @@ func TestRetryPersistent(t *testing.T) { // We have a timed dialer which initially returns err but after RetryDuration // hits maxRetryDuration returns a mock conn. func TestMaxRetryDuration(t *testing.T) { + t.Parallel() + // This test relies on the current value of the max retry duration defined // in the tests, so assert it. - if maxRetryDuration != 2*time.Millisecond { + if defaultTestMaxRetryDuration != 2*time.Millisecond { t.Fatalf("max retry duration of %v is not the required value for test", - maxRetryDuration) + defaultTestMaxRetryDuration) } networkUp := make(chan struct{}) @@ -367,7 +379,7 @@ func TestMaxRetryDuration(t *testing.T) { } connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: timedDialer, @@ -375,9 +387,6 @@ func TestMaxRetryDuration(t *testing.T) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) @@ -403,6 +412,8 @@ func TestMaxRetryDuration(t *testing.T) { // TestNetworkFailure tests that the connection manager handles a network // failure gracefully. func TestNetworkFailure(t *testing.T) { + t.Parallel() + var closeOnce sync.Once const targetOutbound = 5 const retryTimeout = time.Millisecond * 5 @@ -418,7 +429,7 @@ func TestNetworkFailure(t *testing.T) { return nil, errors.New("network down") } var nextAddr atomic.Uint32 - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, @@ -431,9 +442,6 @@ func TestNetworkFailure(t *testing.T) { conn.RemoteAddr()) }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Shutdown the connection manager after the max failed attempts is reached @@ -464,13 +472,11 @@ func TestNetworkFailure(t *testing.T) { // responsive when there are multiple simultaneous failed connections for // persistent conns in the retry state. func TestMultipleFailedConns(t *testing.T) { + t.Parallel() + // Override the max retry duration for this test since it relies on having // multiple connections in the retry state. - curMaxRetryDuration := maxRetryDuration - maxRetryDuration = 500 * time.Millisecond - defer func() { - maxRetryDuration = curMaxRetryDuration - }() + const maxRetryDuration = 500 * time.Millisecond const targetFailed = 5 var dials atomic.Uint32 @@ -483,13 +489,11 @@ func TestMultipleFailedConns(t *testing.T) { } return nil, errors.New("network down") } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ RetryDuration: maxRetryDuration, Dial: errDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } + cmgr.maxRetryDuration = maxRetryDuration _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish several connection requests to localhost IPs. @@ -531,26 +535,25 @@ func TestMultipleFailedConns(t *testing.T) { // TestShutdownFailedConns tests that failed connections are ignored after // connmgr is shutdown. func TestShutdownFailedConns(t *testing.T) { + t.Parallel() + var closeOnce sync.Once dialed := make(chan struct{}) waitDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { closeOnce.Do(func() { close(dialed) }) return nil, errors.New("network down") } - cmgr, err := New(&Config{ - RetryDuration: maxRetryDuration, + cmgr := newTestConnManager(t, &Config{ + RetryDuration: defaultTestMaxRetryDuration, Dial: waitDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Shutdown the connection manager during the retry timeout after a failed // dial attempt. go func() { <-dialed - time.Sleep(maxRetryDuration / 2) + time.Sleep(cmgr.maxRetryDuration / 2) shutdown() }() @@ -566,6 +569,8 @@ func TestShutdownFailedConns(t *testing.T) { // connection correctly cancels the context used to dial and removes the // internal state. func TestRemovePendingConnection(t *testing.T) { + t.Parallel() + // Create a conn manager with an instance of a dialer that'll never succeed. dialed := make(chan struct{}) canceled := make(chan struct{}) @@ -575,12 +580,9 @@ func TestRemovePendingConnection(t *testing.T) { close(canceled) return nil, errors.New("error") } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP. @@ -622,6 +624,8 @@ func TestRemovePendingConnection(t *testing.T) { // connection will not execute the on connection callback, even if a pending // retry succeeds. func TestCancelIgnoreDelayedConnection(t *testing.T) { + t.Parallel() + const retryTimeout = 10 * time.Millisecond // Setup a dialer that returns an error on the first attempt and then blocks @@ -645,16 +649,13 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { } connected := make(chan *Conn) - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: failingDialer, RetryDuration: retryTimeout, OnConnection: func(conn *Conn) { connected <- conn }, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a persistent connection to a localhost IP. @@ -694,6 +695,8 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // dialer that blocks for three times the configured dial timeout before // connecting and ensuring the connection fails as expected. func TestDialTimeout(t *testing.T) { + t.Parallel() + // Create a connection manager instance with a dialer that blocks for three // times the configured dial timeout before connecting. const dialTimeout = time.Millisecond * 20 @@ -708,13 +711,10 @@ func TestDialTimeout(t *testing.T) { return mockDialer(ctx, network, addr) } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: timeoutDialer, DialTimeout: dialTimeout, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection to a localhost IP. @@ -737,6 +737,8 @@ func TestDialTimeout(t *testing.T) { // TestConnectContext ensures the [ConnManager.Connect] method works as intended // when provided with a context that is canceled before a dial attempt succeeds. func TestConnectContext(t *testing.T) { + t.Parallel() + // Create a connection manager instance with a dialer that blocks until its // provided context is canceled. dialed := make(chan struct{}) @@ -745,12 +747,9 @@ func TestConnectContext(t *testing.T) { <-ctx.Done() return nil, ctx.Err() } - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Establish a connection request to a localhost IP with a separate context @@ -850,22 +849,21 @@ func newMockListener(localAddr string) *mockListener { // TestListeners ensures providing listeners to the connection manager along // with an accept callback works properly. func TestListeners(t *testing.T) { + t.Parallel() + // Setup a connection manager with a couple of mock listeners that // notify a channel when they receive mock connections. receivedConns := make(chan *Conn) listener1 := newMockListener("127.0.0.1:9108") listener2 := newMockListener("127.0.0.1:9208") listeners := []net.Listener{listener1, listener2} - cmgr, err := New(&Config{ + cmgr := newTestConnManager(t, &Config{ Listeners: listeners, OnAccept: func(conn *Conn) { receivedConns <- conn }, Dial: mockDialer, }) - if err != nil { - t.Fatalf("New error: %v", err) - } _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) // Fake a couple of mock connections to each of the listeners. From 61122e4761d0a559a3b833580b669c2e0ef0dee7 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 24 May 2026 17:25:31 -0500 Subject: [PATCH 05/51] connmgr: Correct shutdown failed conns test. This updates the test for checking the connection manager cleanly shuts down with failed conns to actualy test what it is intended to. Manual connections do not automatically retry, only persistent connections. --- internal/connmgr/connmanager_test.go | 29 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index fda415cd2..166387a53 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -537,6 +537,7 @@ func TestMultipleFailedConns(t *testing.T) { func TestShutdownFailedConns(t *testing.T) { t.Parallel() + const retryTimeout = time.Second var closeOnce sync.Once dialed := make(chan struct{}) waitDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -544,22 +545,28 @@ func TestShutdownFailedConns(t *testing.T) { return nil, errors.New("network down") } cmgr := newTestConnManager(t, &Config{ - RetryDuration: defaultTestMaxRetryDuration, + RetryDuration: retryTimeout, Dial: waitDialer, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + cmgr.maxRetryDuration = retryTimeout + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Add a persistent connection. + addr := mustParseAddrPort("127.0.0.1:18555") + _, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } // Shutdown the connection manager during the retry timeout after a failed // dial attempt. - go func() { - <-dialed - time.Sleep(cmgr.maxRetryDuration / 2) - shutdown() - }() - - // Establish a connection. - addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) + select { + case <-dialed: + case <-time.After(connTestNonReceiveTimeout): + t.Fatal("timeout waiting for dial") + } + time.Sleep(connTestNonReceiveTimeout) + shutdown() // Ensure clean shutdown of connection manager. wg.Wait() From 1ef516b1380b7e3defc6557eae33351c46cab4df Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:10 -0500 Subject: [PATCH 06/51] connmgr: Add double close tests. This adds tests to ensure closing a connection multiple times works as intended. --- internal/connmgr/connmanager_test.go | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 166387a53..b298f9198 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -306,6 +306,48 @@ func TestTargetOutbound(t *testing.T) { wg.Wait() } +// TestDoubleClose ensures closing a connection multiple times is a noop after +// the first call. +func TestDoubleClose(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ + TargetOutbound: 1, + Dial: mockDialer, + GetNewAddress: func() (net.Addr, error) { + return mustParseAddrPort("127.0.0.1:18555"), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Wait for the connection to be established. + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + + // Override the close func to cleanly detect closes. + var numClosed uint32 + origOnClose := conn.onClose + conn.onClose = func() { + numClosed++ + origOnClose() + } + + // Close the connection multiple times and make sure it only happens once. + for range 3 { + conn.Close() + } + if numClosed != 1 { + t.Fatal("connection closed more than once") + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestRetryPersistent tests that persistent connections are retried. func TestRetryPersistent(t *testing.T) { t.Parallel() From a3f4690705f37fec1c7dccf0ecd7056bad371a86 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:10 -0500 Subject: [PATCH 07/51] connmgr: Add duplicate conn rejection tests. This adds tests to ensure duplication connections are rejected for all possible states. --- internal/connmgr/connmanager_test.go | 129 +++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index b298f9198..aaa59b273 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -934,3 +934,132 @@ func TestListeners(t *testing.T) { shutdown() wg.Wait() } + +// TestRejectDuplicateConns ensures duplicate addresses are rejected. This +// includes: +// - Attempts to dial addresses that already have pending, established, and +// persistent connections (via [ConnManager.Connect] +// - Attempts to add duplicate persistent conns (via [ConnManager.AddPersistent]) +// - Attempts to receive inbound remote addresses that already have pending, +// established, and persistent connections +func TestRejectDuplicateConns(t *testing.T) { + t.Parallel() + + var closeDialedOnce sync.Once + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:18109") + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + closeDialedOnce.Do(func() { close(dialed) }) + <-pending + return mockDialer(ctx, network, addr) + } + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Dial a manual connection and wait for it to become pending. + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("did not receive pending dial before timeout") + } + assertPendingAddr(t, cmgr, addr) + + // Duplicate connect to the pending address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { + t.Fatalf("did not reject duplicate pending connection, err: %v", err) + } + + // Inbound attempts from the pending outbound address should be rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Allow the pending connection to complete. + close(pending) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Duplicate connect to the established address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject duplicate active connection, err: %v", err) + } + + // Inbound attempts from the established outbound address should be + // rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Close the connection and wait for the disconnect. + conn.Close() + assertConnReceived(t, disconnected, conn.ID(), ConnTypeManual) + + // Add a persistent connection back to the same address and wait for it to + // connect since there are no longer any connections to the address. + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Duplicate persistent connection attempts should be rejected. + _, err = cmgr.AddPersistent(addr) + if !errors.Is(err, ErrDuplicatePersistent) { + t.Fatalf("did not reject duplicate persistent connection, err: %v", err) + } + + // Manual connection attempts to persistent connection should be rejected. + _, err = cmgr.Connect(ctx, addr) + if !errors.Is(err, ErrDuplicatePersistent) { + t.Fatalf("did not reject manual connection to persistent, err: %v", err) + } + + // Inbound atempts from the persistent address should be rejected. + go listener.Connect(addr) + assertNoConnReceived(t, inboundConns) + + // Remove the persistent connection, wait for it to disconnect, and ensure + // it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent connection: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // Inbound connections from the same address should now succeed. + go listener.Connect(addr) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + + // Manual connection attempts to the inbound address should be rejected. + if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject outbound for existing inbound conn, err: %v", + err) + } + + // Attempts to add a persistent connection to an existing inbound should be + // rejected. + _, err = cmgr.AddPersistent(addr) + if !errors.Is(err, ErrAlreadyConnected) { + t.Fatalf("did not reject persistent conn for existing inbound conn: %v", + err) + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} From 9e6fe3762d27283861c6fc49634764e7ef13b6ae Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:11 -0500 Subject: [PATCH 08/51] connmgr: Add max persistent conns test. This adds tests to ensure attempts to add more than the maximum allowed number of persistent are rejected. --- internal/connmgr/connmanager_test.go | 80 ++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index aaa59b273..a4dce9dd9 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -396,6 +396,86 @@ func TestRetryPersistent(t *testing.T) { wg.Wait() } +// TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum +// number of persistent connections including a removal and addition of a new +// one after achieving the max. +func TestMaxPersistent(t *testing.T) { + t.Parallel() + + connected := make(chan *Conn) + disconnected := make(chan *Conn) + cmgr := newTestConnManager(t, &Config{ + Dial: mockDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + var numAddrs uint32 + nextAddr := func() net.Addr { + numAddrs++ + addrStr := fmt.Sprintf("127.0.0.%d:18555", numAddrs) + return mustParseAddrPort(addrStr) + } + + // Add the maximum allowed number of persistent conns. + connIDs := make([]uint64, 0, MaxPersistent) + addrs := make([]net.Addr, 0, MaxPersistent) + for range MaxPersistent { + addr := nextAddr() + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection %v: %v", addr, err) + } + connIDs = append(connIDs, connID) + addrs = append(addrs, addr) + + // Wait for the connection. + assertConnReceived(t, connected, connID, ConnTypeManual) + } + + // Attempting to add more than the max allowed number of persistent conns + // should be rejected. + _, err := cmgr.AddPersistent(nextAddr()) + if !errors.Is(err, ErrMaxPersistent) { + t.Fatalf("did not reject > max persistent, err: %v", err) + } + + // Ensure disconnecting the persistent conn does not incorrectly decrement + // the count. + connID, addr := connIDs[0], addrs[0] + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("failed to disconnect persistent conn %v: %v", addr, err) + } + _, err = cmgr.AddPersistent(nextAddr()) + if !errors.Is(err, ErrMaxPersistent) { + t.Fatalf("did not reject max persistent after dc, err: %v", err) + } + + // Remove the first persistent connection, wait for it to disconnect, and + // ensure it is actually removed. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("failed to remove persistent conn %v: %v", addr, err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertRemovedPersistent(t, cmgr, addr) + + // A new persistent conn should now be allowed. + addr = nextAddr() + _, err = cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection %v: %v", addr, err) + } + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestMaxRetryDuration tests the maximum retry duration. // // We have a timed dialer which initially returns err but after RetryDuration From 354adde8c96923390e8eaa49634d2feebfebc752 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:12 -0500 Subject: [PATCH 09/51] connmgr: Add disconnect by id tests. This adds tests to ensure the Disconnect method properly disconnects pending and established connections for both non-persistent and persistent connections. --- internal/connmgr/connmanager_test.go | 151 +++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index a4dce9dd9..70389177c 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -272,6 +272,157 @@ func TestConnectMode(t *testing.T) { wg.Wait() } +// TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects +// pending and established connections for both non-persistent and persistent +// connections. +func TestDisconnect(t *testing.T) { + t.Parallel() + + // Create a connection manager instance with a dialer that has a few + // synchronization channels to notify when a dial attempt is made, to keep + // connection attempts in a pending state, and to notify when the context + // for the attempt is canceled. Whether or not to wait/send the signals are + // controlled by the associated atomic flags. + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + canceled := make(chan struct{}) + var notifyDialed, waitForPending, notifyCanceled atomic.Bool + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if notifyDialed.Load() { + dialed <- struct{}{} + } + if waitForPending.Load() { + <-pending + } + conn, err := mockDialer(ctx, network, addr) + if errors.Is(err, context.Canceled) && notifyCanceled.Load() { + canceled <- struct{}{} + } + return conn, err + } + cmgr := newTestConnManager(t, &Config{ + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Attempt a connection to a localhost IP. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Disconnect the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the disconnected connection attempt and + // then wait for the dialer to signal the context associated with the dial + // was canceled. Finally, ensure the internal pending state is removed. + select { + case pending <- struct{}{}: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } + + // Start a connection attempt and wait for it to be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + go cmgr.Connect(ctx, addr) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Disconnect the established connection and wait for the disconnect + // notification to ensure it is disconnected as intended. + connID = conn.ID() + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Add a persistent connection back to the same address. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Disconnect the persistent connection attempt while it's still pending. + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the disconnected persistent connection + // attempt and then wait for the dialer to signal the context associated + // with the dial was canceled. + select { + case pending <- struct{}{}: + // Ensure the reconnect attempt doesn't notify the dialed chan or + // wait for the pending chan. + notifyDialed.Store(false) + waitForPending.Store(false) + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + + // Wait for the retry to be established. + assertConnReceived(t, connected, connID, ConnTypeManual) + + // Disconnect the established persistent connection and wait for the + // disconnect notification to ensure it is disconnected as intended. + if err := cmgr.Disconnect(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestTargetOutbound tests the target number of outbound connections // configuration option by waiting until all connections are established and // ensuring they are the only connections made. From bd5aaf2d92b949ea208c7da087ff13553336d1bf Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:12 -0500 Subject: [PATCH 10/51] connmgr: Add remove by id tests. This adds tests to ensure the Remove method properly disconnects and removes pending and established connections for both non-persistent and persistent connections. --- internal/connmgr/connmanager_test.go | 167 +++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 70389177c..0e22b60c8 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -423,6 +423,173 @@ func TestDisconnect(t *testing.T) { wg.Wait() } +// TestRemove ensures that [ConnManager.Remove] properly removes pending and +// established connections for both non-persistent and persistent connections. +// +// It also ensures removal of an invalid ID returns the expected error. +func TestRemove(t *testing.T) { + t.Parallel() + + // Create a connection manager instance with a dialer that has a few + // synchronization channels to notify when a dial attempt is made, to keep + // connection attempts in a pending state, and to notify when the context + // for the attempt is canceled. Whether or not to wait/send the signals are + // controlled by the associated atomic flags. + connected := make(chan *Conn) + disconnected := make(chan *Conn) + dialed := make(chan struct{}) + pending := make(chan struct{}) + canceled := make(chan struct{}) + var notifyDialed, waitForPending, notifyCanceled atomic.Bool + pendingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if notifyDialed.Load() { + dialed <- struct{}{} + } + if waitForPending.Load() { + <-pending + } + conn, err := mockDialer(ctx, network, addr) + if errors.Is(err, context.Canceled) && notifyCanceled.Load() { + canceled <- struct{}{} + } + return conn, err + } + cmgr := newTestConnManager(t, &Config{ + Dial: pendingDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Ensure removing an ID that doesn't exist returns the expected error. + if err := cmgr.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { + t.Fatalf("mismatched remove error: got %v, want %v", err, ErrNotFound) + } + + // Attempt a connection to a localhost IP. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + addr := mustParseAddrPort("127.0.0.1:18555") + go cmgr.Connect(ctx, addr) + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Remove the connection attempt while it's still pending. + connID, _ := pendingAddrConnID(cmgr, addr) + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected remove err: %v", err) + } + + // Allow the dialer to proceed with the removed connection attempt and then + // wait for the dialer to signal the context associated with the dial was + // canceled. Finally, ensure the internal pending state is removed. + select { + case pending <- struct{}{}: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + if _, ok := pendingAddrConnID(cmgr, addr); ok { + t.Fatalf("connection %s is still pending", addr) + } + + // Start a connection attempt and wait for it to be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + go cmgr.Connect(ctx, addr) + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + + // Remove the established connection and wait for the disconnect + // notification to ensure it is disconnected as intended. + connID = conn.ID() + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Add a persistent connection back to the same address. + notifyDialed.Store(true) + waitForPending.Store(true) + notifyCanceled.Store(true) + connID, err := cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + + // Wait for the connection manager to attempt to dial and ensure the + // connection is marked as pending while the dialer is blocked. + select { + case <-dialed: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for dial") + } + assertPendingAddr(t, cmgr, addr) + + // Remove the persistent connection attempt while it's still pending. + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + + // Allow the dialer to proceed with the removed persistent connection + // attempt and then wait for the dialer to signal the context associated + // with the dial was canceled. + select { + case pending <- struct{}{}: + // Ensure the reconnect attempt doesn't notify the dialed chan or + // wait for the pending chan. + notifyDialed.Store(false) + waitForPending.Store(false) + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting to signal pending") + } + select { + case <-canceled: + case <-time.After(time.Millisecond * 5): + t.Fatal("timeout waiting for cancel") + } + + // Add a persistent connection back to the same address and wait for it to + // be established. + notifyDialed.Store(false) + waitForPending.Store(false) + notifyCanceled.Store(false) + connID, err = cmgr.AddPersistent(addr) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + conn2 := assertConnReceived(t, connected, connID, ConnTypeManual) + + // Remove the established persistent connection and wait for the disconnect + // notification to ensure it is disconnected as intended. Also, ensure the + // persistent connection entry is removed. + connID = conn2.ID() + if err := cmgr.Remove(connID); err != nil { + t.Fatalf("unexpected disconnect err: %v", err) + } + assertConnReceived(t, disconnected, connID, ConnTypeManual) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() +} + // TestTargetOutbound tests the target number of outbound connections // configuration option by waiting until all connections are established and // ensuring they are the only connections made. From 3b78fbb51e4449068af9fc27f52d70ca459049a8 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 17 May 2026 07:28:13 -0500 Subject: [PATCH 11/51] connmgr: Update README.md. This updates the connmgr package README.md to match the new design and capabilities. --- internal/connmgr/README.md | 58 +++++++++++++++++++++++---------- internal/connmgr/connmanager.go | 2 ++ internal/connmgr/doc.go | 21 ------------ 3 files changed, 42 insertions(+), 39 deletions(-) delete mode 100644 internal/connmgr/doc.go diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 46114366a..b9299d864 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -5,26 +5,48 @@ connmgr [![ISC License](https://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org) [![Doc](https://img.shields.io/badge/doc-reference-blue.svg)](https://pkg.go.dev/github.com/decred/dcrd/internal/connmgr) -Package connmgr implements a generic Decred network connection manager. - ## Overview -This package handles all the general connection concerns such as maintaining a -set number of outbound connections, sourcing peers, banning, limiting max -connections, tor lookup, etc. - -The package provides a generic connection manager which is able to accept -connection requests from a source or a set of given addresses, dial them and -notify the caller on connections. The main intended use is to initialize a pool -of active connections and maintain them to remain connected to the P2P network. - -In addition the connection manager provides the following utilities: - -- Notifications on connections or disconnections -- Handle failures and retry new addresses from the source -- Connect only to specified addresses -- Permanent connections with increasing backoff retry timers -- Disconnect or Remove an established connection +Package `connmgr` provides a flexible and robust context-aware connection +manager for inbound, outbound, and persistent network connections with retry +logic. + +It handles all general connection lifecycle concerns such as accepting inbound +connections, automatically maintaining a set number of outbound connections, +maintaining persistent connections, and limiting max connections. + +The design has a strong emphasis on reliability, readability, and efficiency under high connection load while also aiming to provide an ergonomic API. + +The following is a brief overview of the key features: + +- Full context support +- Inbound listening + - Accepts inbound connections on provided `Listeners` + - Uses connection shedding for rejected inbound connections +- Automatic outbound maintenance + - Maintains up to `TargetOutbound` normal outbound connections via a provided + address source (`GetNewAddress`) +- Persistent connections + - Maintains up to `MaxPersistent` addresses that are automatically retried + with exponential backoff on disconnect +- Manual connections + - Supports manual connection establishment via `Connect` +- Duplicate address prevention + - Rejects duplicate connections to and from the same address (host:port) +- Rich managed connections via `Conn` + - Connection types for differentiated handling + - Automatic cleanup on connection close + - Concrete parsed address access +- Manual disconnection and removal + - Ability to disconnect / remove established, pending, and persistent + connections via `Disconnect` and `Remove` +- Notification callbacks + - Provides callbacks for connection establishment and disconnects +- Graceful network outage handling + - Automatic connection attempts are throttled during network outages +- Clear and actionable programatically-detectable errors + +A full suite of tests is provided to help ensure proper functionality. ## License diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index e4e75b06d..453e4a33d 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -3,6 +3,8 @@ // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. +// Package connmgr provides a robust connection manager for inbound, outbound, +// and persistent network connections with retry logic. package connmgr import ( diff --git a/internal/connmgr/doc.go b/internal/connmgr/doc.go deleted file mode 100644 index 3eb872c3c..000000000 --- a/internal/connmgr/doc.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2016 The btcsuite developers -// Copyright (c) 2017-2022 The Decred developers -// Use of this source code is governed by an ISC -// license that can be found in the LICENSE file. - -/* -Package connmgr implements a generic Decred network connection manager. - -# Deprecated - -This module is deprecated and is no longer maintained. Callers are encouraged -to use github.com/decred/dcrd/addrmgr/vX for methods that were moved to it -instead. - -# Connection Manager Overview - -Connection manager handles all the general connection concerns such as -maintaining a set number of outbound connections, sourcing peers, banning, -limiting max connections, tor lookup, etc. -*/ -package connmgr From 1cc2fa725c10039873fe7d0344f5fd5460275d37 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Fri, 22 May 2026 20:58:21 -0500 Subject: [PATCH 12/51] connmgr: Add internal state test assertions. This adds a couple of test helpers for asserting the internal state of the connection manager updates all tests to call the new helpers throughout. The first one asserts the internal maps are all coherent and do not violate any preconditions. The second one asserts clean shutdown. --- internal/connmgr/connmanager_test.go | 162 +++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 0e22b60c8..ecbb0ccf0 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -12,6 +12,7 @@ import ( "io" "net" "net/netip" + "reflect" "sync" "sync/atomic" "testing" @@ -122,6 +123,88 @@ func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { return cmgr } +// assertConnManagerInternalState ensures the internal state of the passed +// connection manager instance is coherent. +func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { + t.Helper() + + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + + // Assert established persistent conns have the correct connection type. + for id, conn := range cm.active { + if _, ok := cm.persistent[id]; ok { + want := ConnTypeManual + if got := conn.Type(); got != want { + t.Fatalf("bad conn type in active map: %v != %v", got, want) + } + } + } + + // Assert the pending and active maps are mutually exclusive for both conn + // IDs and addrs. + // + // Also build a map of addrs to conn IDs in the pending, active, and + // persistent maps for the checks below. + connIDByAddr := make(map[string]uint64) + for id, info := range cm.pending { + if _, ok := cm.active[id]; ok { + t.Fatalf("conn ID %d is both pending and active", id) + } + connIDByAddr[info.addr.String()] = id + } + for id, conn := range cm.active { + if _, ok := cm.pending[id]; ok { + t.Fatalf("conn ID %d is both pending and active", id) + } + addrStr := conn.remoteAddr.String() + if _, ok := connIDByAddr[addrStr]; ok { + t.Fatalf("addr %s is both pending and active", addrStr) + } + connIDByAddr[addrStr] = id + } + for id, entry := range cm.persistent { + // Assert the conn ID of established/pending persistent conns matches. + addrStr := entry.addr.String() + if existingID, ok := connIDByAddr[addrStr]; ok && existingID != id { + t.Fatalf("conn ID for addr %s mismatch: %d != %d", addrStr, + existingID, id) + } + connIDByAddr[addrStr] = id + } + + // Assert the addr to conn ID mappings match the values obtained from + // manually constructing them. + if !reflect.DeepEqual(cm.connIDByAddr, connIDByAddr) { + t.Fatalf("mismatched conn ID by addr maps\ngot: %v\nwant %v", + cm.connIDByAddr, connIDByAddr) + } +} + +// assertConnManagerCleanShutdown ensures the internal state of the passed +// connection manager is fully cleaned up as expected. It must only be called +// after [ConnManager.Run] returns. +func assertConnManagerCleanShutdown(t *testing.T, cm *ConnManager) { + t.Helper() + + cm.connMtx.Lock() + defer cm.connMtx.Unlock() + + if len(cm.active) != 0 { + t.Fatalf("active map is not empty: %d entries", len(cm.active)) + } + if len(cm.pending) != 0 { + t.Fatalf("pending map is not empty: %d entries", len(cm.pending)) + } + if len(cm.persistent) != 0 { + t.Fatalf("persistent map is not empty: %d entries", len(cm.persistent)) + } + if len(cm.connIDByAddr) != 0 { + t.Fatalf("conn ID by addr map not empty: %d entries", + len(cm.connIDByAddr)) + } +} + // TestNewConfig tests that new ConnManager config is validated as expected. func TestNewConfig(t *testing.T) { t.Parallel() @@ -266,10 +349,12 @@ func TestConnectMode(t *testing.T) { // Ensure that only a single connection is received. assertConnReceived(t, connected, 0, ConnTypeManual) assertNoConnReceived(t, connected) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects @@ -328,12 +413,14 @@ func TestDisconnect(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Disconnect the connection attempt while it's still pending. connID, _ := pendingAddrConnID(cmgr, addr) if err := cmgr.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Allow the dialer to proceed with the disconnected connection attempt and // then wait for the dialer to signal the context associated with the dial @@ -351,6 +438,7 @@ func TestDisconnect(t *testing.T) { if _, ok := pendingAddrConnID(cmgr, addr); ok { t.Fatalf("connection %s is still pending", addr) } + assertConnManagerInternalState(t, cmgr) // Start a connection attempt and wait for it to be established. notifyDialed.Store(false) @@ -358,6 +446,7 @@ func TestDisconnect(t *testing.T) { notifyCanceled.Store(false) go cmgr.Connect(ctx, addr) conn := assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Disconnect the established connection and wait for the disconnect // notification to ensure it is disconnected as intended. @@ -366,6 +455,7 @@ func TestDisconnect(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address. notifyDialed.Store(true) @@ -375,6 +465,7 @@ func TestDisconnect(t *testing.T) { if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -384,11 +475,13 @@ func TestDisconnect(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Disconnect the persistent connection attempt while it's still pending. if err := cmgr.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Allow the dialer to proceed with the disconnected persistent connection // attempt and then wait for the dialer to signal the context associated @@ -410,6 +503,7 @@ func TestDisconnect(t *testing.T) { // Wait for the retry to be established. assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Disconnect the established persistent connection and wait for the // disconnect notification to ensure it is disconnected as intended. @@ -417,10 +511,12 @@ func TestDisconnect(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRemove ensures that [ConnManager.Remove] properly removes pending and @@ -485,6 +581,7 @@ func TestRemove(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Remove the connection attempt while it's still pending. connID, _ := pendingAddrConnID(cmgr, addr) @@ -508,6 +605,7 @@ func TestRemove(t *testing.T) { if _, ok := pendingAddrConnID(cmgr, addr); ok { t.Fatalf("connection %s is still pending", addr) } + assertConnManagerInternalState(t, cmgr) // Start a connection attempt and wait for it to be established. notifyDialed.Store(false) @@ -515,6 +613,7 @@ func TestRemove(t *testing.T) { notifyCanceled.Store(false) go cmgr.Connect(ctx, addr) conn := assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Remove the established connection and wait for the disconnect // notification to ensure it is disconnected as intended. @@ -523,6 +622,7 @@ func TestRemove(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address. notifyDialed.Store(true) @@ -532,6 +632,7 @@ func TestRemove(t *testing.T) { if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -546,6 +647,7 @@ func TestRemove(t *testing.T) { if err := cmgr.Remove(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Allow the dialer to proceed with the removed persistent connection // attempt and then wait for the dialer to signal the context associated @@ -564,6 +666,7 @@ func TestRemove(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for cancel") } + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address and wait for it to // be established. @@ -575,6 +678,7 @@ func TestRemove(t *testing.T) { t.Fatalf("failed to add persistent connection: %v", err) } conn2 := assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Remove the established persistent connection and wait for the disconnect // notification to ensure it is disconnected as intended. Also, ensure the @@ -584,10 +688,12 @@ func TestRemove(t *testing.T) { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestTargetOutbound tests the target number of outbound connections @@ -618,10 +724,12 @@ func TestTargetOutbound(t *testing.T) { assertConnReceived(t, connected, 0, ConnTypeOutbound) } assertNoConnReceived(t, connected) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestDoubleClose ensures closing a connection multiple times is a noop after @@ -644,6 +752,7 @@ func TestDoubleClose(t *testing.T) { // Wait for the connection to be established. conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + assertConnManagerInternalState(t, cmgr) // Override the close func to cleanly detect closes. var numClosed uint32 @@ -660,10 +769,12 @@ func TestDoubleClose(t *testing.T) { if numClosed != 1 { t.Fatal("connection closed more than once") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRetryPersistent tests that persistent connections are retried. @@ -700,6 +811,7 @@ func TestRetryPersistent(t *testing.T) { conn.Close() assertConnReceived(t, disconnected, connID, ConnTypeManual) assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Remove the persistent connection, wait for it to disconnect, and ensure // it is actually removed. @@ -708,10 +820,12 @@ func TestRetryPersistent(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum @@ -754,6 +868,7 @@ func TestMaxPersistent(t *testing.T) { // Wait for the connection. assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) } // Attempting to add more than the max allowed number of persistent conns @@ -762,6 +877,7 @@ func TestMaxPersistent(t *testing.T) { if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject > max persistent, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Ensure disconnecting the persistent conn does not incorrectly decrement // the count. @@ -773,6 +889,7 @@ func TestMaxPersistent(t *testing.T) { if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject max persistent after dc, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Remove the first persistent connection, wait for it to disconnect, and // ensure it is actually removed. @@ -781,6 +898,7 @@ func TestMaxPersistent(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // A new persistent conn should now be allowed. addr = nextAddr() @@ -788,10 +906,12 @@ func TestMaxPersistent(t *testing.T) { if err != nil { t.Fatalf("failed to add persistent connection %v: %v", addr, err) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestMaxRetryDuration tests the maximum retry duration. @@ -843,10 +963,12 @@ func TestMaxRetryDuration(t *testing.T) { }) const timeout = connTestReceiveTimeout + networkUpTimeout assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestNetworkFailure tests that the connection manager handles a network @@ -906,6 +1028,8 @@ func TestNetworkFailure(t *testing.T) { t.Fatalf("unexpected number of dials - got %v, want <= %v", gotDials, wantMaxDials) } + + assertConnManagerCleanShutdown(t, cmgr) } // TestMultipleFailedConns ensures that the connection manager remains @@ -944,6 +1068,7 @@ func TestMultipleFailedConns(t *testing.T) { t.Fatalf("unexpected add err: %v", err) } } + assertConnManagerInternalState(t, cmgr) // Wait for the target number of dials and ensure they happen simultaneously // by checking it happens before the retry timeout. @@ -952,6 +1077,7 @@ func TestMultipleFailedConns(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatal("did not reach target number of dials before timeout") } + assertConnManagerInternalState(t, cmgr) // Ensure that the connection manager still responds to requests while the // failed connections are still retrying. @@ -966,10 +1092,12 @@ func TestMultipleFailedConns(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatal("timeout servicing connmgr requests") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestShutdownFailedConns tests that failed connections are ignored after @@ -997,6 +1125,7 @@ func TestShutdownFailedConns(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + assertConnManagerInternalState(t, cmgr) // Shutdown the connection manager during the retry timeout after a failed // dial attempt. @@ -1010,6 +1139,7 @@ func TestShutdownFailedConns(t *testing.T) { // Ensure clean shutdown of connection manager. wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRemovePendingConnection ensures that removing a pending outbound @@ -1035,6 +1165,7 @@ func TestRemovePendingConnection(t *testing.T) { // Establish a connection request to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") go cmgr.Connect(ctx, addr) + assertConnManagerInternalState(t, cmgr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -1044,12 +1175,14 @@ func TestRemovePendingConnection(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Cancel the connection attempt while it's still pending. connID, _ := pendingAddrConnID(cmgr, addr) if err := cmgr.Remove(connID); err != nil { t.Fatalf("unexpected remove err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the dialer to signal the context associated with the dial was // canceled and ensure the internal pending state is removed. @@ -1061,10 +1194,12 @@ func TestRemovePendingConnection(t *testing.T) { if _, ok := pendingAddrConnID(cmgr, addr); ok { t.Fatalf("connection %s is still pending", addr) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestCancelIgnoreDelayedConnection tests that a canceled pending persistent @@ -1111,6 +1246,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + assertConnManagerInternalState(t, cmgr) // Wait for the retry and ensure the connection is pending. select { @@ -1119,6 +1255,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { t.Fatalf("did not get retry before timeout") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Remove the connection and then immediately allow the next connection to // succeed. @@ -1132,10 +1269,12 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // timeout window to ensure the connection manager's backoff is allowed to // properly elapse. assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestDialTimeout ensure [Config.Timeout] works as intended by creating a @@ -1167,6 +1306,7 @@ func TestDialTimeout(t *testing.T) { // Establish a connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") go cmgr.Connect(ctx, addr) + assertConnManagerInternalState(t, cmgr) // Wait to receive the signal that the dialer context was cancelled, which // means the dial timeout was hit. @@ -1175,10 +1315,12 @@ func TestDialTimeout(t *testing.T) { case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestConnectContext ensures the [ConnManager.Connect] method works as intended @@ -1218,6 +1360,7 @@ func TestConnectContext(t *testing.T) { t.Fatal("timeout waiting for dial") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Cancel the connection context, wait for the error from connect, and // ensure it is the expected error. @@ -1231,10 +1374,12 @@ func TestConnectContext(t *testing.T) { case <-time.After(10 * time.Millisecond): t.Fatal("timeout waiting for dial cancellation") } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // mockListener implements the net.Listener interface and is used to test @@ -1327,10 +1472,12 @@ func TestListeners(t *testing.T) { for range expectedNumConns { assertConnReceived(t, receivedConns, 0, ConnTypeInbound) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } // TestRejectDuplicateConns ensures duplicate addresses are rejected. This @@ -1379,6 +1526,7 @@ func TestRejectDuplicateConns(t *testing.T) { t.Fatal("did not receive pending dial before timeout") } assertPendingAddr(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Duplicate connect to the pending address should be rejected. if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { @@ -1388,24 +1536,29 @@ func TestRejectDuplicateConns(t *testing.T) { // Inbound attempts from the pending outbound address should be rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) // Allow the pending connection to complete. close(pending) conn := assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Duplicate connect to the established address should be rejected. if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject duplicate active connection, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Inbound attempts from the established outbound address should be // rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) // Close the connection and wait for the disconnect. conn.Close() assertConnReceived(t, disconnected, conn.ID(), ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Add a persistent connection back to the same address and wait for it to // connect since there are no longer any connections to the address. @@ -1414,22 +1567,26 @@ func TestRejectDuplicateConns(t *testing.T) { t.Fatalf("failed to add persistent connection: %v", err) } assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) // Duplicate persistent connection attempts should be rejected. _, err = cmgr.AddPersistent(addr) if !errors.Is(err, ErrDuplicatePersistent) { t.Fatalf("did not reject duplicate persistent connection, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Manual connection attempts to persistent connection should be rejected. _, err = cmgr.Connect(ctx, addr) if !errors.Is(err, ErrDuplicatePersistent) { t.Fatalf("did not reject manual connection to persistent, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Inbound atempts from the persistent address should be rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) // Remove the persistent connection, wait for it to disconnect, and ensure // it is actually removed. @@ -1438,16 +1595,19 @@ func TestRejectDuplicateConns(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) + assertConnManagerInternalState(t, cmgr) // Inbound connections from the same address should now succeed. go listener.Connect(addr) assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + assertConnManagerInternalState(t, cmgr) // Manual connection attempts to the inbound address should be rejected. if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject outbound for existing inbound conn, err: %v", err) } + assertConnManagerInternalState(t, cmgr) // Attempts to add a persistent connection to an existing inbound should be // rejected. @@ -1456,8 +1616,10 @@ func TestRejectDuplicateConns(t *testing.T) { t.Fatalf("did not reject persistent conn for existing inbound conn: %v", err) } + assertConnManagerInternalState(t, cmgr) // Ensure clean shutdown of connection manager. shutdown() wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) } From 76bf20218ce048223f98cc83d80501d251e83b4d Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 18 May 2026 15:59:15 -0500 Subject: [PATCH 13/51] connmgr: Support whitelisting. Currently the whitelisting logic happens in the server which makes it inaccessible to the connection manager. In order to pave the way for supporting various connection-related logic that currently happens in the server, but ideally should be happening in the connection manager, this adds basic support for whitelisting CIDR prefixes to the connection manager. The connection manager config struct now accepts a slice of prefixes and a new method named IsWhitelisted is added. Note that this only adds support . It does not update anything to use the new functionality yet. --- internal/connmgr/connmanager.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 453e4a33d..a2df96398 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "net" + "net/netip" "strconv" "sync" "sync/atomic" @@ -260,6 +261,10 @@ type Config struct { // DialTimeout specifies the amount of time to wait for a connection to // complete before giving up. DialTimeout time.Duration + + // Whitelists specifies CIDR address prefixes to whitelist. Whitelisted + // addresses are exempt from banning and certain connection limits. + Whitelists []netip.Prefix } // ConnManager provides a manager to handle network connections. @@ -317,6 +322,22 @@ type ConnManager struct { connIDByAddr map[string]uint64 } +// IsWhitelisted returns whether the IP address is included in the whitelisted +// networks and IPs. +func (cm *ConnManager) IsWhitelisted(addr *addrmgr.NetAddress) bool { + if len(cm.cfg.Whitelists) == 0 { + return false + } + + ip, _ := netip.AddrFromSlice(addr.IP) + for _, prefix := range cm.cfg.Whitelists { + if prefix.Contains(ip) { + return true + } + } + return false +} + // checkShutdown returns [ErrShutdown] when the connection manager quit channel // has been closed. func (cm *ConnManager) checkShutdown() error { From 78e734ff2d1b825392e967419014a84c705c946f Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Wed, 20 May 2026 19:21:44 -0500 Subject: [PATCH 14/51] connmgr: Add whitelist detection tests. This adds tests to ensure the new whitelist detection method works as expected. --- internal/connmgr/connmanager_test.go | 87 ++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index ecbb0ccf0..b9e1f512a 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -219,6 +219,93 @@ func TestNewConfig(t *testing.T) { }) } +// TestIsWhitelisted ensures [ConnManager.IsWhitelisted] works as expected. +func TestIsWhitelisted(t *testing.T) { + type perManagerTest struct { + addr string // address to test against whitelist + whitelisted bool // expected whitelisted result + } + + tests := []struct { + name string // test description + prefixes []netip.Prefix // CIDR prefixes to whitelist + perManagerTests []perManagerTest // tests to run against the prefixes + }{{ + name: "no whitelisted entries", + prefixes: nil, + perManagerTests: []perManagerTest{ + {"1.2.3.4:18555", false}, + {"127.0.0.1:18555", false}, + }, + }, { + name: "single /32 IPv4 entry", + prefixes: []netip.Prefix{ + netip.MustParsePrefix("1.2.3.4/32"), + }, + perManagerTests: []perManagerTest{ + {"1.2.3.4:18555", true}, + {"1.2.3.4:9108", true}, + {"[::1.2.3.4]:18555", false}, // IPv4 in IPv6 + {"1.2.3.5:18555", false}, + }, + }, { + name: "single /128 IPv6 entry", + prefixes: []netip.Prefix{ + netip.MustParsePrefix("::1.2.3.4/128"), + }, + perManagerTests: []perManagerTest{ + {"[::1.2.3.4]:18555", true}, + {"[::1.2.3.4]:9108", true}, + {"1.2.3.4:18555", false}, // IPv4 doesn't match IPv4 in IPv6 + {"[::1.2.3.5]:9108", false}, + }, + }, { + name: "mixed IPv4 and IPv6 with different prefix lengths", + prefixes: []netip.Prefix{ + netip.MustParsePrefix("12.13.14.0/24"), + netip.MustParsePrefix("20.21.22.23/8"), + netip.MustParsePrefix("fe80::/64"), + }, + perManagerTests: []perManagerTest{ + {"12.13.14.1:18555", true}, + {"12.13.14.255:18555", true}, + {"12.13.15.0:18555", false}, + {"20.0.0.0:18555", true}, + {"20.0.0.0:9108", true}, + {"20.255.255.255:18555", true}, + {"20.255.255.255:9108", true}, + {"21.0.0.0:18555", false}, + {"[fe80::1]:18555", true}, + {"[fe80::1]:9108", true}, + {"[fe80::ffff:ffff:ffff:ffff]:18555", true}, + {"[fe80::ffff:ffff:ffff:ffff]:1234", true}, + {"[fe80::1:ffff:ffff:ffff:ffff]:18555", false}, + }, + }} + + for _, test := range tests { + // Parse the whitelist entries for the test. + cmgr := newTestConnManager(t, &Config{ + Dial: mockDialer, + Whitelists: test.prefixes, + }) + + for _, pmTest := range test.perManagerTests { + mAddr := mockAddr{"tcp", pmTest.addr} + addr, err := stdlibNetAddrToAddrMgrNetAddr(mAddr) + if err != nil { + t.Fatalf("%q-%q: failed to parse address: %v", test.name, + pmTest.addr, err) + } + if got := cmgr.IsWhitelisted(addr); got != pmTest.whitelisted { + t.Errorf("%q-%q: mismatched result -- got %v, want %v", + test.name, pmTest.addr, got, pmTest.whitelisted) + continue + } + } + } +} + // assertConnID ensures the provided connection has the given ID. func assertConnID(t *testing.T, conn *Conn, wantID uint64) { t.Helper() From b2f60079feec3f6115d6c72a7ad645e6119412ad Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 18 May 2026 15:59:20 -0500 Subject: [PATCH 15/51] server: Integrate connmgr whitelisting. This modifies the server to pass in the parsed whitelist entries to the connection manager config and the relevant code to make use of the new method it exposes. Finally, it removes the no longer used local isWhitelisted method. --- server.go | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/server.go b/server.go index f61afaf33..3a7b1b671 100644 --- a/server.go +++ b/server.go @@ -513,6 +513,7 @@ func newServerPeer(s *server, conn *connmgr.Conn, remoteAddr *addrmgr.NetAddress conn: conn, remoteAddr: remoteAddr, persistent: s.connManager.IsPersistent(conn.ID()), + isWhitelisted: s.connManager.IsWhitelisted(remoteAddr), knownAddresses: apbf.NewFilter(maxKnownAddrsPerPeer, knownAddrsFPRate), quit: make(chan struct{}), getDataQueue: make(chan []*wire.InvVect, maxConcurrentGetDataReqs), @@ -2289,7 +2290,6 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn *connmgr.Conn) { sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for inbound peer %s: %v", remoteNetAddr, err) @@ -2323,7 +2323,6 @@ func (s *server) outboundPeerConnected(ctx context.Context, conn *connmgr.Conn) sp := newServerPeer(s, conn, remoteNetAddr) sp.Peer = peer.NewOutboundPeer(newPeerConfig(sp), conn.RemoteAddr(), conn) - sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for outbound peer %s: %v", conn.RemoteAddr(), err) @@ -4367,6 +4366,7 @@ func newServer(ctx context.Context, profiler *profileServer, s.outboundPeerConnected(ctx, conn) }, GetNewAddress: newAddressFunc, + Whitelists: cfg.whitelists, }) if err != nil { return nil, err @@ -4631,19 +4631,3 @@ func addLocalAddress(addrMgr *addrmgr.AddrManager, addr string, services wire.Se return nil } - -// isWhitelisted returns whether the IP address is included in the whitelisted -// networks and IPs. -func isWhitelisted(addr *addrmgr.NetAddress) bool { - if len(cfg.whitelists) == 0 { - return false - } - - ip, _ := netip.AddrFromSlice(addr.IP) - for _, prefix := range cfg.whitelists { - if prefix.Contains(ip) { - return true - } - } - return false -} From dc4cd8870d4baddb475931015fb783b1966c9b9d Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 17:10:47 -0500 Subject: [PATCH 16/51] connmgr: Update README.md for whitelist support. --- internal/connmgr/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index b9299d864..cf43ac92a 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -33,6 +33,8 @@ The following is a brief overview of the key features: - Supports manual connection establishment via `Connect` - Duplicate address prevention - Rejects duplicate connections to and from the same address (host:port) +- Whitelist support + - CIDR-based whitelists that allow bypassing certain limits and restrictions - Rich managed connections via `Conn` - Connection types for differentiated handling - Automatic cleanup on connection close From 26ea64f9de59bc7b06f73a7b30c801551a2603a8 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:47 -0500 Subject: [PATCH 17/51] connmgr: Add try acquire support to semaphore. This adds a new TryAcquire method to the context-aware semaphore. As the name implies, the method supports conditionally acquiring the semaphore only when resources are immediately available. In other words, it will not block when there are no resources immediately available. --- internal/connmgr/semaphore.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/internal/connmgr/semaphore.go b/internal/connmgr/semaphore.go index fb7d7eed4..5f831d9be 100644 --- a/internal/connmgr/semaphore.go +++ b/internal/connmgr/semaphore.go @@ -27,6 +27,27 @@ func (s semaphore) Acquire(ctx context.Context) bool { return true } +// TryAcquire attempts to acquire the semaphore without blocking when there are +// no resources immediately available. +// +// It returns true with a nil error on success. It returns false with a nil +// error when the semaphore is at capacity and no permit is available. +// +// Finally, it returns false with the error associated with the context +// immediately when the context is already canceled or timed out at the time of +// the call. It does not attempt to acquire the semaphore in that case. +func (s semaphore) TryAcquire(ctx context.Context) (bool, error) { + if ctx.Err() != nil { + return false, ctx.Err() + } + select { + case s <- struct{}{}: + return true, nil + default: + } + return false, nil +} + // Release release the semaphore. func (s semaphore) Release() { select { From 0b85a384e808e8196e883d4a5c7391728a6e6d78 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:48 -0500 Subject: [PATCH 18/51] connmgr: Add semaphore try acquire tests. This adds tests for the new TryAcquire method on the context-aware semaphore to ensure the semantics work as expected. --- internal/connmgr/semaphore_test.go | 75 +++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/internal/connmgr/semaphore_test.go b/internal/connmgr/semaphore_test.go index 9542176df..1d7d5b750 100644 --- a/internal/connmgr/semaphore_test.go +++ b/internal/connmgr/semaphore_test.go @@ -6,6 +6,7 @@ package connmgr import ( "context" + "errors" "testing" "time" ) @@ -21,10 +22,24 @@ func TestSemaphore(t *testing.T) { return sem.Acquire(ctx) } + // Create a closure that tries to acquire a semaphore via the nonblocking + // method with a timeout. + timedTryAcquire := func(sem semaphore, timeout time.Duration) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + if timeout == 0 { + cancel() + } else { + defer cancel() + } + return sem.TryAcquire(ctx) + } + // perSemTest describes a test to run against the same semaphore. type perSemTest struct { name string // test description numAcquires uint32 // num to acquire + numTries uint32 // num to try to acquire via nonblocking method + cancelTry bool // whether or not to cancel nonblocking try numReleases uint32 // num to release } @@ -69,6 +84,45 @@ func TestSemaphore(t *testing.T) { numReleases: 5, }}, want: []bool{true, true, true, true, true, true, true, false}, + }, { + name: "nonblocking tryacquire and blocking acquire mixed", + cap: 3, + perSemTests: []perSemTest{{ + name: "cap 3 (0 acquired): try 1, release 2", + numTries: 1, + numReleases: 2, + }, { + name: "cap 3 (0 acquired): acquire 2, try 1, release 1", + numAcquires: 2, + numTries: 1, + numReleases: 1, + }, { + name: "cap 3 (2 acquired): acquire 1, try 2, release 3", + numAcquires: 1, + numTries: 2, + numReleases: 3, + }}, + want: []bool{true, true, true, true, true, false, false}, + }, { + name: "nonblocking tryacquire with canceled context", + cap: 1, + perSemTests: []perSemTest{{ + name: "cap 1 (0 acquired): try 1 (canceled), release 0", + numTries: 1, + cancelTry: true, + numReleases: 0, + }, { + name: "cap 1 (0 acquired): acquire 1, try 1, release 1", + numAcquires: 1, + numTries: 1, + numReleases: 1, + }, { + name: "cap 1 (0 acquired): try 2, release 1", + numAcquires: 0, + numTries: 2, + numReleases: 1, + }}, + want: []bool{false, true, false, true, false}, }} for _, test := range tests { @@ -77,13 +131,30 @@ func TestSemaphore(t *testing.T) { sem := makeSemaphore(test.cap) results := make([]bool, 0, len(test.want)) - // Perform each sequence of acquires and releases as specified by the - // per semaphore tests. + // Perform each sequence of acquires, try acquires, and releases as + // specified by the per semaphore tests. for _, psTest := range test.perSemTests { const timeout = 10 * time.Millisecond for range psTest.numAcquires { results = append(results, timedAcquire(sem, timeout)) } + for range psTest.numTries { + // Override timeout with a duration 0 and expected error when + // the flag to force the context for the try acquire to be + // canceled is specified. + var wantErr error + tryTimeout := timeout + if psTest.cancelTry { + tryTimeout = 0 + wantErr = context.DeadlineExceeded + } + acquired, err := timedTryAcquire(sem, tryTimeout) + if !errors.Is(err, wantErr) { + t.Fatalf("%q: unexpected try acquire error: got %v, want %v", + psTest.name, err, wantErr) + } + results = append(results, acquired) + } for range psTest.numReleases { sem.Release() } From 6f13d36568887d21e48ca52ccc9d64a0254e2b24 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:49 -0500 Subject: [PATCH 19/51] connmgr: Limit total overall normal connections. The current overall total connection limits are enforced by the server rather than the connection manager. This is not ideal for many reasons, but one of the most important consequences is that it makes DoS attacks easier. Another example of some less than ideal behavior that it allows is that some rare combinations of events can lead to temporary extra connection churn. It is much more robust and natural to perform the limiting in the connection manager itself via semaphores. That approach not only significantly hardens the server against DoS attacks and solves various edge cases present in the current code, it also paves the way for even more advanced features such as traffic shaping in the future. To that end, this adds semaphore-based limiting for the total overall number of normal connections to the connection manager and removes the relevant current limiting for it from the server. Normal connections are the automatic outbound, manual outbound, and inbound connections. Persistent connections, on the other hand, are not subject to the limit since they have their own limiting. This is consistent with them not being subject to the automatic target outbound limit either. --- internal/connmgr/connmanager.go | 140 +++++++++++++++++++++++++++----- internal/connmgr/error.go | 4 + internal/connmgr/error_test.go | 1 + internal/connmgr/log.go | 9 ++ rpcadaptors.go | 8 -- server.go | 12 +-- 6 files changed, 135 insertions(+), 39 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index a2df96398..78135ef1e 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -44,6 +44,10 @@ const ( // base times the number of retries that have been done. defaultMaxRetryDuration = time.Minute * 5 + // defaultMaxNormalConns is the default maximum number of normal inbound, + // outbound, and pending connections to permit. + defaultMaxNormalConns = 125 + // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 @@ -233,11 +237,25 @@ type Config struct { // connections in that case. OnAccept func(*Conn) + // MaxNormalConns is the maximum number of normal inbound, outbound, and + // pending connections to permit. Defaults to 125. + // + // Persistent connections do not count against this limit. They have their + // own maximum defined by [MaxPersistent]. + // + // Whitelisted connections and some connections with special permissions are + // also exempt. As a result, the total number of connections may exceed + // this value. + MaxNormalConns uint32 + // TargetOutbound is the number of outbound network connections to maintain // automatically. Defaults to 8. // // Persistent connections do not count against this value. They have their // own maximum limit defined by [MaxPersistent]. + // + // This will be forced to the smaller of the specified value (or its default + // value when unspecified) and [Config.MaxNormalConns]. TargetOutbound uint32 // RetryDuration is the duration to wait before retrying connection @@ -290,10 +308,16 @@ type ConnManager struct { // It is a buffered channel with size [MaxPersistent]. runPersistentChan chan *persistentEntry - // outboundSem limits the number of active outbound connections. It does - // not apply to persistent connections which are separately limited to - // [MaxPersistent]. - activeOutboundsSem semaphore + // These semaphores are used to enforce max limits on the number of + // connections of different kinds. They do not apply to persistent + // connections which are separately limited to [MaxPersistent]. + // + // totalNormalConnsSem limits the total overall number of normal inbound, + // outbound, and pending connections. + // + // outboundSem limits the number of active outbound connections. + totalNormalConnsSem semaphore + activeOutboundsSem semaphore // The fields below this point are all protected by the connection mutex. connMtx sync.Mutex @@ -545,6 +569,8 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // to the address // - [ErrAlreadyConnected] when there is already an established connection to // the address +// - [ErrMaxNormalConns] when there are already the maximum allowed number of +// normal connections (inbound, outbound, and pending) // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -694,9 +720,13 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // the connection manager. // // Attempts to dial addresses that already have an established, pending, or -// persistent connection will return an error as described below. +// persistent connection or would exceed max allowed limits will return an error +// as described below. +// +// The connection will have type [ConnTypeManual] and the following connection +// limits are enforced: // -// The connection will have type [ConnTypeManual]. +// - Total normal connections ([Config.MaxNormalConns]) // // Note that the context parameter to this function and the lifecycle context // may be independent. @@ -710,6 +740,8 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // to the address // - [ErrAlreadyConnected] when there is already an established connection to // the address +// - [ErrMaxNormalConns] when there are already the maximum allowed number of +// normal connections (inbound, outbound, and pending) // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -717,7 +749,21 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // // This function is safe for concurrent access. func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error) { - conn, err := cm.dial(ctx, addr, ConnTypeManual, nil, nil) + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) + if err != nil { + if sErr := cm.checkShutdown(); sErr != nil { + return nil, sErr + } + return nil, err + } + if !acquired { + maxAllowed := cm.cfg.MaxNormalConns + str := fmt.Sprintf("a maximum of %d %s is allowed", maxAllowed, + pickNoun(maxAllowed, "connection", "connections")) + return nil, MakeError(ErrMaxNormalConns, str) + } + onClose := cm.totalNormalConnsSem.Release + conn, err := cm.dial(ctx, addr, ConnTypeManual, onClose, nil) if err != nil { return nil, err } @@ -847,6 +893,18 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) defer log.Tracef("Listener handler done for %s", listener.Addr()) for ctx.Err() == nil { + // The following is intentionally implementing active connection + // shedding by accepting connections and then immediately disconnecting + // them after the [net.Listener.Accept] call if any policies are + // violated. + // + // Reversing it and blocking until a permit is available and only then + // calling Accept would cause the connections to build up in the kernel. + // Then, since the kernel will still create the 3-way handshake, clients + // would connect and hang until their own timeouts are hit, and, + // eventually, the entire service could appear entirely down if the SYN + // queue were to fill. It also would not allow implementing better + // additional policies. netConn, err := listener.Accept() if err != nil { // Only log the error if not forcibly shutting down. @@ -881,7 +939,29 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) } cm.connMtx.Unlock() - go func(netConn net.Conn) { + // Require a permit to allow the inbound connection unless the address + // has special permissions (e.g. whitelisted). + // + // Attempt to acquire a permit via a non-blocking call and immediately + // disconnect if unsuccessful so that all blocking happens on + // [net.Listener.Accept] for the reasons described above. + requirePermit := !cm.IsWhitelisted(rAddr) + if requirePermit { + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) + if err != nil { + netConn.Close() + continue + } + if !acquired { + maxAllowed := cm.cfg.MaxNormalConns + log.Debugf("Dropped connection from %v: a maximum of %d %s is "+ + "allowed", rAddr, maxAllowed, pickNoun(maxAllowed, + "connection", "connections")) + netConn.Close() + continue + } + } + go func(netConn net.Conn, requirePermit bool) { // Create a new connection instance with the next globally unique // connection ID, add an entry to the map that tracks all active // connections, and invoke the configured accept callback with it. @@ -897,6 +977,9 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) cm.connMtx.Unlock() log.Debugf("Disconnected from %v (id: %d, type: %v)", rAddr, id, connType) + if requirePermit { + cm.totalNormalConnsSem.Release() + } } conn = newConn(cm, netConn, id, connType, rAddr, onClose) cm.connMtx.Lock() @@ -905,7 +988,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) log.Debugf("Accepted connection from %v (id: %d, type: %v)", rAddr, id, connType) cm.cfg.OnAccept(conn) - }(netConn) + }(netConn, requirePermit) } } @@ -1147,16 +1230,28 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { return } + // Wait for a permit to make another overall connection. This limits + // the total number of normal connections while the previous limits the + // total number of automatic outbound connections. + if !cm.totalNormalConnsSem.Acquire(ctx) { + cm.activeOutboundsSem.Release() + return + } + addr, err := cm.cfg.GetNewAddress() if err != nil { failedAttempts.Add(1) log.Debugf("Failed to get address for outbound connection: %v", err) + cm.totalNormalConnsSem.Release() cm.activeOutboundsSem.Release() continue } go func(addr net.Addr) { - onClose := cm.activeOutboundsSem.Release + onClose := func() { + cm.totalNormalConnsSem.Release() + cm.activeOutboundsSem.Release() + } conn, err := cm.dial(ctx, addr, ConnTypeOutbound, onClose, nil) if err != nil { failedAttempts.Add(1) @@ -1252,23 +1347,28 @@ func New(cfg *Config) (*ConnManager, error) { if cfg.Dial == nil { return nil, MakeError(ErrDialNil, "dial cannot be nil") } - // Default to sane values + // Default to sane values. if cfg.RetryDuration <= 0 { cfg.RetryDuration = defaultRetryDuration } + if cfg.MaxNormalConns == 0 { + cfg.MaxNormalConns = defaultMaxNormalConns + } if cfg.TargetOutbound == 0 { cfg.TargetOutbound = defaultTargetOutbound } + cfg.TargetOutbound = min(cfg.TargetOutbound, cfg.MaxNormalConns) cm := ConnManager{ - cfg: *cfg, // Copy so caller can't mutate - quit: make(chan struct{}), - maxRetryDuration: defaultMaxRetryDuration, - runPersistentChan: make(chan *persistentEntry, MaxPersistent), - activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), - persistent: make(map[uint64]*persistentEntry, MaxPersistent), - pending: make(map[uint64]*pendingConnInfo), - active: make(map[uint64]*Conn, cfg.TargetOutbound), - connIDByAddr: make(map[string]uint64), + cfg: *cfg, // Copy so caller can't mutate + quit: make(chan struct{}), + maxRetryDuration: defaultMaxRetryDuration, + runPersistentChan: make(chan *persistentEntry, MaxPersistent), + totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), + activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), + persistent: make(map[uint64]*persistentEntry, MaxPersistent), + pending: make(map[uint64]*pendingConnInfo), + active: make(map[uint64]*Conn, cfg.TargetOutbound), + connIDByAddr: make(map[string]uint64), } return &cm, nil } diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 9c87dad6e..6b313d89f 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -22,6 +22,10 @@ const ( // already has an established connection. ErrAlreadyConnected = ErrorKind("ErrAlreadyConnected") + // ErrMaxNormalConns indicates a connection attempt (inbound or outbound) + // would exceed the maximum allowed number of normal connections. + ErrMaxNormalConns = ErrorKind("ErrMaxNormalConns") + // ErrMaxPersistent indicates an attempt to add more than the maximum // allowed number of persistent connections. ErrMaxPersistent = ErrorKind("ErrMaxPersistent") diff --git a/internal/connmgr/error_test.go b/internal/connmgr/error_test.go index d4e5d2262..b3881bf7c 100644 --- a/internal/connmgr/error_test.go +++ b/internal/connmgr/error_test.go @@ -19,6 +19,7 @@ func TestErrorKindStringer(t *testing.T) { {ErrDialNil, "ErrDialNil"}, {ErrAlreadyPending, "ErrAlreadyPending"}, {ErrAlreadyConnected, "ErrAlreadyConnected"}, + {ErrMaxNormalConns, "ErrMaxNormalConns"}, {ErrMaxPersistent, "ErrMaxPersistent"}, {ErrDuplicatePersistent, "ErrDuplicatePersistent"}, {ErrNotFound, "ErrNotFound"}, diff --git a/internal/connmgr/log.go b/internal/connmgr/log.go index 4bf44f579..f6ba6f5f4 100644 --- a/internal/connmgr/log.go +++ b/internal/connmgr/log.go @@ -19,3 +19,12 @@ var log = slog.Disabled func UseLogger(logger slog.Logger) { log = logger } + +// pickNoun returns the singular or plural form of a noun depending on the count +// n. +func pickNoun[T ~uint32 | ~uint64](n T, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} diff --git a/rpcadaptors.go b/rpcadaptors.go index a6826e047..b88fd5bee 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -130,14 +130,6 @@ func (cm *rpcConnManager) Connect(ctx context.Context, addr string, permanent bo return err } - // Limit max number of total peers. - cm.server.peerState.Lock() - count := cm.server.peerState.count() - cm.server.peerState.Unlock() - if count >= cfg.MaxPeers { - return errors.New("max peers reached") - } - // Attempt to add a persistent peer when requested. connManager := cm.server.connManager if permanent { diff --git a/server.go b/server.go index 3a7b1b671..5d019fdb4 100644 --- a/server.go +++ b/server.go @@ -2703,17 +2703,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { return false } - // Limit max number of total peers. However, allow whitelisted inbound - // peers regardless. - if state.count()+1 > cfg.MaxPeers && !isInboundWhitelisted { - srvrLog.Infof("Max peers reached [%d] - disconnecting peer %s", - cfg.MaxPeers, sp) - sp.Disconnect() - // TODO: how to handle permanent peers here? - // they should be rescheduled. - return false - } - // Add the new peer. if sp.Inbound() { state.inboundPeers[sp.ID()] = sp @@ -4359,6 +4348,7 @@ func newServer(ctx context.Context, profiler *profileServer, s.inboundPeerConnected(ctx, conn) }, RetryDuration: connectionRetryInterval, + MaxNormalConns: uint32(cfg.MaxPeers), TargetOutbound: s.targetOutbound, Dial: s.attemptDcrdDial, DialTimeout: cfg.DialTimeout, From 9b30eee308297969338e786f3da277f94ff98b1b Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 19 May 2026 14:11:50 -0500 Subject: [PATCH 20/51] connmgr: Add total max normal conns tests. This adds tests to ensure that the new max normal connection limiting properly enforces the limit including automatic outbound, manual outbound, and inbound connections. It also ensures that it not applied to persistent connections. --- internal/connmgr/connmanager_test.go | 154 +++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index b9e1f512a..aa03622f7 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -1710,3 +1710,157 @@ func TestRejectDuplicateConns(t *testing.T) { wg.Wait() assertConnManagerCleanShutdown(t, cmgr) } + +// TestMaxNormalConns ensures the connection manager limits the total number of +// normal connections to [Config.MaxNormalConns] including automatic outbound, +// manual outbound, and inbound connections. It also ensures that it is not +// applied to persistent connections. +func TestMaxNormalConns(t *testing.T) { + t.Parallel() + + // nextAddr is a convenience func to return a new unique address with every + // invocation. + var numAddrs atomic.Uint32 + nextAddr := func() net.Addr { + addrStr := fmt.Sprintf("10.0.0.%d:18555", numAddrs.Add(1)) + return mustParseAddrPort(addrStr) + } + + // Constants for the number of various normal connection types to test + // overall max normal connection limits. + const ( + targetOutbound = 3 + targetManual = 4 + targetInbound = 5 + maxNormalConns = targetOutbound + targetManual + targetInbound + ) + connected := make(chan *Conn) + disconnected := make(chan *Conn) + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:9108") + var pauseTargetOutbound atomic.Bool + var totalPausedAddrs atomic.Uint32 + hitMaxFailedAttempts := make(chan struct{}) + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + MaxNormalConns: maxNormalConns, + TargetOutbound: targetOutbound, + RetryDuration: 50 * time.Millisecond, + Dial: mockDialer, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + GetNewAddress: func() (net.Addr, error) { + if pauseTargetOutbound.Load() { + total := totalPausedAddrs.Add(1) + if total == maxFailedAttempts { + hitMaxFailedAttempts <- struct{}{} + } + return nil, errors.New("network down") + } + return nextAddr(), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + cmgr.maxRetryDuration = cmgr.cfg.RetryDuration + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Wait for the expected number of target outbound conns to be established. + outbounds := make([]*Conn, 0, targetOutbound) + for len(outbounds) < targetOutbound { + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + outbounds = append(outbounds, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Establish target number of inbounds to the listener and wait for them to + // be established. + go func() { + for range targetInbound { + listener.Connect(nextAddr()) + } + }() + inbounds := make([]*Conn, 0, targetInbound) + for len(inbounds) < targetInbound { + conn := assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + inbounds = append(inbounds, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Establish target number of manual connections and wait for them to be + // established. + go func() { + for range targetManual { + go cmgr.Connect(ctx, nextAddr()) + } + }() + manualConns := make([]*Conn, 0, targetManual+1) + for len(manualConns) < targetManual { + conn := assertConnReceived(t, connected, 0, ConnTypeManual) + manualConns = append(manualConns, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure manual connections that would exceed the max allowed normal + // connections are rejected. + _, err := cmgr.Connect(ctx, nextAddr()) + if !errors.Is(err, ErrMaxNormalConns) { + t.Fatalf("did not reject manual connection at max allowed, err: %v", err) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure inbound connections that would exceed the max allowed normal + // connections are rejected. + go listener.Connect(nextAddr()) + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) + + // Pause the target outbound dials and remove one of the target outbound + // connections to make room for another manual connection. Then wait for + // the max failures to be hit so attempts are paused for a retry timeout. + pauseTargetOutbound.Store(true) + outboundConn := outbounds[0] + outboundConn.Close() + assertConnReceived(t, disconnected, outboundConn.ID(), ConnTypeOutbound) + select { + case <-hitMaxFailedAttempts: + time.Sleep(connTestReceiveTimeout) + case <-time.After(maxFailedAttempts * connTestReceiveTimeout): + t.Fatal("did not reach max failed attempts before timeout") + } + assertConnManagerInternalState(t, cmgr) + + // Establish another manual connection to take the place of the target + // outbound connection that was just closed and wait for it to be + // established. + go cmgr.Connect(ctx, nextAddr()) + assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Unpause the target outbound dials and ensure no additional automatic + // outbound connections are made despite being under the target outbound due + // to max total conns. + pauseTargetOutbound.Store(false) + assertNoConnReceivedTimeout(t, connected, connTestNonReceiveTimeout+ + cmgr.cfg.RetryDuration) + assertConnManagerInternalState(t, cmgr) + + // Ensure persistent connections are not subject to the max total normal + // connections by adding one and waiting for it to be established. + connID, err := cmgr.AddPersistent(nextAddr()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) +} From aedda39720c7cab55b95a4074f87f2fb890372ea Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 17:14:30 -0500 Subject: [PATCH 21/51] connmgr: Update README.md for total conn limits. --- internal/connmgr/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index cf43ac92a..850f2620e 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -31,6 +31,8 @@ The following is a brief overview of the key features: with exponential backoff on disconnect - Manual connections - Supports manual connection establishment via `Connect` +- Connection limits + - Limits total normal (non-persistent) connections to `MaxNormalConns` - Duplicate address prevention - Rejects duplicate connections to and from the same address (host:port) - Whitelist support From 1a32f180c40cb957ff82d4c87dac42977dc38d23 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 16:44:17 -0500 Subject: [PATCH 22/51] connmgr: Limit max connections per host. Similar to the recent total normal connection limiting, the current per-host connection limits are enforced by the server. For similar reasons, it is much more robust and natural to perform the limiting early in the connection manager. With that in mind, this implements the per-host connection limiting in the connection manager and removes the relevant current limiting for it from the server. The limiting is applied to inbound, outbound, and persistent connections. The new limiting is handled early in both the inbound and outbound paths now which allows it to take advantage of fast connection shedding for inbound connections and to preemptively prevent all outbound attempts that would exceed the limit regardless of source. It also provides the flexibility to apply independent special permissions in the future. This also slightly changes the semantics to exempt whitelisted addresses for both inbound and outbound connections as opposed to only inbound connections. --- internal/connmgr/connmanager.go | 126 ++++++++++++++++++++++++++++++-- internal/connmgr/error.go | 4 + server.go | 38 ++-------- 3 files changed, 130 insertions(+), 38 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 78135ef1e..f11d28953 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -48,6 +48,11 @@ const ( // outbound, and pending connections to permit. defaultMaxNormalConns = 125 + // defaultMaxConnsPerHost is the default maximum number of connections with + // the same host to permit. It does not apply to whitelisted or loopback + // addresses. + defaultMaxConnsPerHost = 5 + // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 @@ -193,9 +198,10 @@ func (c *Conn) Type() ConnectionType { // pendingConnInfo houses information about a pending connection attempt. type pendingConnInfo struct { - id uint64 - addr *addrmgr.NetAddress - cancel context.CancelFunc + id uint64 + addr *addrmgr.NetAddress + hostKey string + cancel context.CancelFunc } // persistentEntry houses information about a persistent connection that has @@ -248,6 +254,21 @@ type Config struct { // this value. MaxNormalConns uint32 + // MaxConnsPerHost is the maximum number of connections with the same host + // to permit. Defaults to 5. + // + // This applies to inbound, outbound, and persistent connections. However, + // in practice, it is highly unlikely that outbound connections will hit the + // default limit (unless intentionally connecting manually) because: + // + // - connections to the same host:port are rejected and it is extremely rare + // for the same host to serve multiple instances on different ports + // - all automatic outbound connections are heavily biased toward different + // network groups + // + // This limit is not applied to whitelisted or loopback connections. + MaxConnsPerHost uint32 + // TargetOutbound is the number of outbound network connections to maintain // automatically. Defaults to 8. // @@ -344,6 +365,11 @@ type ConnManager struct { // (host:port). It is kept in sync with the persistent, pending, and active // maps and is primarily used to efficiently reject duplicate connections. connIDByAddr map[string]uint64 + + // perHostCounts provides fast O(1) lookup of the number of entries per + // host. It is kept in sync with the persistent, pending, and active maps + // and is primarily used to efficiently enforce per-host connection limits. + perHostCounts map[string]uint32 } // IsWhitelisted returns whether the IP address is included in the whitelisted @@ -403,6 +429,32 @@ func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { return netAddr, nil } +// addrHostKey returns the host portion of the passed address as a string +// suitable for use as a map key. +func addrHostKey(addr net.Addr) string { + if na, ok := addr.(*addrmgr.NetAddress); ok { + return net.IP(na.IP).String() + } + + addrStr := addr.String() + host, _, err := net.SplitHostPort(addrStr) + if err == nil { + return host + } + return addrStr +} + +// decrementPerHostCount decrements the reference count for the provided host +// and cleans up the associated entry when there are no more references. +// +// This function MUST be called with the connection mutex held (writes). +func (cm *ConnManager) decrementPerHostCount(hostKey string) { + cm.perHostCounts[hostKey]-- + if cm.perHostCounts[hostKey] == 0 { + delete(cm.perHostCounts, hostKey) + } +} + // addPendingInfo adds information about a pending connection attempt to the // local state. // @@ -411,6 +463,7 @@ func (cm *ConnManager) addPendingInfo(info *pendingConnInfo) { cm.pending[info.id] = info if _, ok := cm.persistent[info.id]; !ok { cm.connIDByAddr[info.addr.String()] = info.id + cm.perHostCounts[info.hostKey]++ } } @@ -421,6 +474,7 @@ func (cm *ConnManager) removePendingInfo(info *pendingConnInfo) { delete(cm.pending, info.id) if _, ok := cm.persistent[info.id]; !ok { delete(cm.connIDByAddr, info.addr.String()) + cm.decrementPerHostCount(info.hostKey) } } @@ -431,6 +485,7 @@ func (cm *ConnManager) addActiveConn(conn *Conn) { cm.active[conn.id] = conn if _, ok := cm.persistent[conn.id]; !ok { cm.connIDByAddr[conn.remoteAddr.String()] = conn.id + cm.perHostCounts[addrHostKey(&conn.remoteAddr)]++ } } @@ -448,6 +503,7 @@ func (cm *ConnManager) removeActiveConn(conn *Conn) { delete(cm.active, conn.id) if _, ok := cm.persistent[conn.id]; !ok { delete(cm.connIDByAddr, conn.remoteAddr.String()) + cm.decrementPerHostCount(addrHostKey(&conn.remoteAddr)) } } @@ -457,6 +513,7 @@ func (cm *ConnManager) removeActiveConn(conn *Conn) { func (cm *ConnManager) addPersistentEntry(entry *persistentEntry) { cm.persistent[entry.id] = entry cm.connIDByAddr[entry.addr.String()] = entry.id + cm.perHostCounts[addrHostKey(entry.addr)]++ } // removePersistentEntry removes a persistent connection entry from the local @@ -469,6 +526,7 @@ func (cm *ConnManager) removePersistentEntry(entry *persistentEntry) { _, active := cm.active[entry.id] if !pending && !active { delete(cm.connIDByAddr, entry.addr.String()) + cm.decrementPerHostCount(addrHostKey(entry.addr)) } } @@ -541,6 +599,28 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { return nil } +// rejectMaxConnsPerHost returns an error if adding an additional connection +// with the provided host address would exceed [Config.MaxConnsPerHost] and is +// not exempt. +// +// This function MUST be called with the connection mutex held (reads). +func (cm *ConnManager) rejectMaxConnsPerHost(addr *addrmgr.NetAddress, hostKey string, isWhitelisted bool) error { + // Whitelisted and loopback addresses are exempt. + isLoopback := net.IP(addr.IP).IsLoopback() + if isWhitelisted || isLoopback { + return nil + } + + maxAllowed := cm.cfg.MaxConnsPerHost + if numConns := cm.perHostCounts[hostKey]; numConns+1 > maxAllowed { + str := fmt.Sprintf("a maximum of %d %s per host is allowed", maxAllowed, + pickNoun(maxAllowed, "connection", "connections")) + return MakeError(ErrMaxConnsPerHost, str) + } + + return nil +} + // dial attempts to connect to the provided address and returns a connection // configured with the provided params on success. // @@ -553,6 +633,10 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // and pending connections are rejected when a non-nil persistent connection ID // is passed. // +// The following connection limits are enforced: +// +// - Total connections with the same host ([Config.MaxConnsPerHost]) +// // On success, the returned connection is configured to remove itself from the // set of all active connections and invoke the provided on close callback (if // set) when it is closed. @@ -571,6 +655,8 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // the address // - [ErrMaxNormalConns] when there are already the maximum allowed number of // normal connections (inbound, outbound, and pending) +// - [ErrMaxConnsPerHost] when there are already the maximum allowed number of +// connections (pending, active, and persistent) with the same host // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -598,6 +684,8 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect if err != nil { return nil, err } + rAddrHostKey := addrHostKey(rAddr) + isWhitelisted := cm.IsWhitelisted(rAddr) // Reject attempts to dial addresses that are already connected (or in the // process of it). Additionally, reject attempts to dial existing @@ -617,6 +705,14 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect return nil, err } + // Limit the max number of connections per host. + err = cm.rejectMaxConnsPerHost(rAddr, rAddrHostKey, isWhitelisted) + if err != nil { + cm.connMtx.Unlock() + log.Debugf("Rejected connection to %v: %v", rAddr, err) + return nil, err + } + // Apply a dial timeout if requested. Otherwise, use a regular cancel // context to support canceling the pending connection later. var cancel context.CancelFunc @@ -635,7 +731,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect } else { connID = cm.nextConnID.Add(1) } - info := &pendingConnInfo{connID, rAddr, cancel} + info := &pendingConnInfo{connID, rAddr, rAddrHostKey, cancel} cm.addPendingInfo(info) cm.connMtx.Unlock() defer func() { @@ -727,6 +823,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // limits are enforced: // // - Total normal connections ([Config.MaxNormalConns]) +// - Total connections with the same host ([Config.MaxConnsPerHost]) // // Note that the context parameter to this function and the lifecycle context // may be independent. @@ -742,6 +839,8 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // the address // - [ErrMaxNormalConns] when there are already the maximum allowed number of // normal connections (inbound, outbound, and pending) +// - [ErrMaxConnsPerHost] when there are already the maximum allowed number of +// connections (pending, active, and persistent) with the same host // - [ErrShutdown] when the connection manager is shutting down // - [context.Canceled] or [context.DeadlineExceeded] depending on the // provided context or when the dialer fails to establish a connection @@ -921,6 +1020,8 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) netConn.Close() continue } + rAddrHostKey := addrHostKey(rAddr) + isWhitelisted := cm.IsWhitelisted(rAddr) // Reject connections with the same host:port as any existing pending, // established, or persistent connections. Note that this does NOT @@ -929,7 +1030,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) // // The aforementioned behavior is intentional as it allows connections // from the same host to be independently limited to more than one - // elsewhere. + // below. cm.connMtx.Lock() if err := cm.rejectDuplicateAddr(rAddr); err != nil { cm.connMtx.Unlock() @@ -937,6 +1038,15 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) netConn.Close() continue } + + // Limit the max number of connections per host. + err = cm.rejectMaxConnsPerHost(rAddr, rAddrHostKey, isWhitelisted) + if err != nil { + cm.connMtx.Unlock() + log.Debugf("Dropped connection from %v: %v", rAddr, err) + netConn.Close() + continue + } cm.connMtx.Unlock() // Require a permit to allow the inbound connection unless the address @@ -945,7 +1055,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) // Attempt to acquire a permit via a non-blocking call and immediately // disconnect if unsuccessful so that all blocking happens on // [net.Listener.Accept] for the reasons described above. - requirePermit := !cm.IsWhitelisted(rAddr) + requirePermit := !isWhitelisted if requirePermit { acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) if err != nil { @@ -1354,6 +1464,9 @@ func New(cfg *Config) (*ConnManager, error) { if cfg.MaxNormalConns == 0 { cfg.MaxNormalConns = defaultMaxNormalConns } + if cfg.MaxConnsPerHost == 0 { + cfg.MaxConnsPerHost = defaultMaxConnsPerHost + } if cfg.TargetOutbound == 0 { cfg.TargetOutbound = defaultTargetOutbound } @@ -1369,6 +1482,7 @@ func New(cfg *Config) (*ConnManager, error) { pending: make(map[uint64]*pendingConnInfo), active: make(map[uint64]*Conn, cfg.TargetOutbound), connIDByAddr: make(map[string]uint64), + perHostCounts: make(map[string]uint32), } return &cm, nil } diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index 6b313d89f..966475894 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -26,6 +26,10 @@ const ( // would exceed the maximum allowed number of normal connections. ErrMaxNormalConns = ErrorKind("ErrMaxNormalConns") + // ErrMaxConnsPerHost indicates a connection attempt (inbound or outbound) + // would exceed the maximum allowed number of connections per host. + ErrMaxConnsPerHost = ErrorKind("ErrMaxConnsPerHost") + // ErrMaxPersistent indicates an attempt to add more than the maximum // allowed number of persistent connections. ErrMaxPersistent = ErrorKind("ErrMaxPersistent") diff --git a/server.go b/server.go index 5d019fdb4..327ed87ba 100644 --- a/server.go +++ b/server.go @@ -2618,20 +2618,6 @@ func (s *server) considerReportedAddr(from *serverPeer, addr *wire.NetAddress) { s.considerReportedAddrOutbound(from, addr) } -// connectionsWithIP returns the number of connections with the given IP. -// -// This function MUST be called with the embedded mutex locked (for reads). -func (ps *peerState) connectionsWithIP(ip net.IP) int { - var total int - ps.forAllPeers(func(sp *serverPeer) { - if ip.Equal(sp.remoteAddr.IP) { - total++ - } - - }) - return total -} - // handleAddPeer deals with adding new peers and includes logic such as // categorizing the type of peer, limiting the maximum allowed number of peers, // and local external address resolution. @@ -2690,19 +2676,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { defer state.Unlock() state.Lock() - // Limit max number of connections from a single IP. However, allow - // whitelisted inbound peers and localhost connections regardless. - isInboundWhitelisted := sp.isWhitelisted && sp.Inbound() - peerIP := net.IP(sp.remoteAddr.IP) - if cfg.MaxSameIP > 0 && !isInboundWhitelisted && !peerIP.IsLoopback() && - state.connectionsWithIP(peerIP)+1 > cfg.MaxSameIP { - - srvrLog.Infof("Max connections with %s reached [%d] - disconnecting "+ - "peer", sp, cfg.MaxSameIP) - sp.Disconnect() - return false - } - // Add the new peer. if sp.Inbound() { state.inboundPeers[sp.ID()] = sp @@ -4347,11 +4320,12 @@ func newServer(ctx context.Context, profiler *profileServer, OnAccept: func(conn *connmgr.Conn) { s.inboundPeerConnected(ctx, conn) }, - RetryDuration: connectionRetryInterval, - MaxNormalConns: uint32(cfg.MaxPeers), - TargetOutbound: s.targetOutbound, - Dial: s.attemptDcrdDial, - DialTimeout: cfg.DialTimeout, + RetryDuration: connectionRetryInterval, + MaxNormalConns: uint32(cfg.MaxPeers), + MaxConnsPerHost: uint32(cfg.MaxSameIP), + TargetOutbound: s.targetOutbound, + Dial: s.attemptDcrdDial, + DialTimeout: cfg.DialTimeout, OnConnection: func(conn *connmgr.Conn) { s.outboundPeerConnected(ctx, conn) }, From 8f9ff5809eb599d5d439e7160ba3c1e60a7be6f1 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 16:44:17 -0500 Subject: [PATCH 23/51] connmgr: Add max per-host conn tests. This adds tests to ensure that the new max connections per host limiting properly enforces the limit including automatic outbound, manual outbound, inbound, and persistent connections. It also tests whitelisted addresses are exempt. --- internal/connmgr/connmanager_test.go | 175 ++++++++++++++++++++++++++- 1 file changed, 173 insertions(+), 2 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index aa03622f7..c1b7965d1 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -144,14 +144,18 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { // Assert the pending and active maps are mutually exclusive for both conn // IDs and addrs. // - // Also build a map of addrs to conn IDs in the pending, active, and - // persistent maps for the checks below. + // Also build a map of addrs to conn IDs and tally the per host counts in + // the pending, active, and persistent maps for the checks below. connIDByAddr := make(map[string]uint64) + perHostCounts := make(map[string]uint32) for id, info := range cm.pending { if _, ok := cm.active[id]; ok { t.Fatalf("conn ID %d is both pending and active", id) } connIDByAddr[info.addr.String()] = id + if _, ok := cm.persistent[id]; !ok { + perHostCounts[addrHostKey(info.addr)]++ + } } for id, conn := range cm.active { if _, ok := cm.pending[id]; ok { @@ -162,6 +166,9 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("addr %s is both pending and active", addrStr) } connIDByAddr[addrStr] = id + if _, ok := cm.persistent[id]; !ok { + perHostCounts[addrHostKey(&conn.remoteAddr)]++ + } } for id, entry := range cm.persistent { // Assert the conn ID of established/pending persistent conns matches. @@ -170,6 +177,7 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("conn ID for addr %s mismatch: %d != %d", addrStr, existingID, id) } + perHostCounts[addrHostKey(entry.addr)]++ connIDByAddr[addrStr] = id } @@ -179,6 +187,13 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("mismatched conn ID by addr maps\ngot: %v\nwant %v", cm.connIDByAddr, connIDByAddr) } + + // Assert the per host counts match the values obtained from manually + // tallying them. + if !reflect.DeepEqual(cm.perHostCounts, perHostCounts) { + t.Fatalf("mismatched per host count maps\ngot: %v\nwant %v", + cm.perHostCounts, perHostCounts) + } } // assertConnManagerCleanShutdown ensures the internal state of the passed @@ -203,6 +218,10 @@ func assertConnManagerCleanShutdown(t *testing.T, cm *ConnManager) { t.Fatalf("conn ID by addr map not empty: %d entries", len(cm.connIDByAddr)) } + if len(cm.perHostCounts) != 0 { + t.Fatalf("per host counts map not empty: %d entries", + len(cm.perHostCounts)) + } } // TestNewConfig tests that new ConnManager config is validated as expected. @@ -1864,3 +1883,155 @@ func TestMaxNormalConns(t *testing.T) { wg.Wait() assertConnManagerCleanShutdown(t, cmgr) } + +// TestMaxConnsPerHost ensures the connection manager limits the total number of +// connections with the same host to [Config.MaxConnsPerHost] including +// automatic outbound, manual outbound, inbound, and persistent connections. It +// also tests whitelisted addresses are exempt. +func TestMaxConnsPerHost(t *testing.T) { + t.Parallel() + + // nextSameHost is a convenience func to return a new address to the same IP + // with a different port on every invocation. + var nextPort atomic.Uint32 + nextSameHost := func() net.Addr { + addrStr := fmt.Sprintf("10.10.0.1:%d", nextPort.Add(1)+1024) + return mustParseAddrPort(addrStr) + } + + // nextSameHostWhitelisted is a convenience func to return a new address to + // the same whitelisted IP with a different port on every invocation. + allowedIP := netip.MustParseAddr("10.20.0.1") + nextSameWhitelistedHost := func() net.Addr { + addrStr := fmt.Sprintf("%s:%d", allowedIP, nextPort.Add(1)+1024) + return mustParseAddrPort(addrStr) + } + + const maxConnsPerHost = 3 + connected := make(chan *Conn, 1) + disconnected := make(chan *Conn, 1) + inboundConns := make(chan *Conn) + listener := newMockListener("127.0.0.1:9108") + var pauseTargetOutbound atomic.Bool + var totalPausedAddrs atomic.Uint32 + hitMaxFailedAttempts := make(chan struct{}) + cmgr := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + MaxNormalConns: 30, // High enough to not interfere with per-host tests. + MaxConnsPerHost: maxConnsPerHost, + TargetOutbound: maxConnsPerHost, + RetryDuration: 50 * time.Millisecond, + Dial: mockDialer, + Whitelists: []netip.Prefix{netip.PrefixFrom(allowedIP, 32)}, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + GetNewAddress: func() (net.Addr, error) { + if pauseTargetOutbound.Load() { + total := totalPausedAddrs.Add(1) + if total == maxFailedAttempts { + close(hitMaxFailedAttempts) + } + return nil, errors.New("network down") + } + return nextSameHost(), nil + }, + OnConnection: func(conn *Conn) { + connected <- conn + }, + OnDisconnection: func(conn *Conn) { + disconnected <- conn + }, + }) + cmgr.maxRetryDuration = cmgr.cfg.RetryDuration + ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + + // Wait for the maximum allowed non-whitelisted per-host automatic outbound + // conns. + outboundConns := make([]*Conn, 0, maxConnsPerHost) + for len(outboundConns) < maxConnsPerHost { + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + outboundConns = append(outboundConns, conn) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure non-whitelisted manual connections that would exceed the max + // allowed per-host connections are rejected. + _, err := cmgr.Connect(ctx, nextSameHost()) + if !errors.Is(err, ErrMaxConnsPerHost) { + t.Fatalf("did not reject manual connection at per-host limit, err: %v", + err) + } + assertConnManagerInternalState(t, cmgr) + + // Ensure non-whitelisted inbound connections that would exceed the max + // allowed per-host connections are rejected. + go listener.Connect(nextSameHost()) + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cmgr) + + // Ensure whitelisted manual connections are allowed to exceed the per-host + // limit. + for range maxConnsPerHost + 1 { + go cmgr.Connect(ctx, nextSameWhitelistedHost()) + assertConnReceived(t, connected, 0, ConnTypeManual) + } + + // Ensure whitelisted inbound connections are allowed to exceed the per-host + // limit. + go listener.Connect(nextSameWhitelistedHost()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound) + assertConnManagerInternalState(t, cmgr) + + // Ensure whitelisted persistent connections are allowed to exceed the + // per-host limit. + connID, err := cmgr.AddPersistent(nextSameWhitelistedHost()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertConnReceived(t, connected, connID, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Pause the target outbound dials and remove one of the target outbound + // connections to make room for another manual connection with the same + // host. Then wait for the max failures to be hit so attempts are paused + // for a retry timeout. + pauseTargetOutbound.Store(true) + outboundConn := outboundConns[0] + outboundConn.Close() + assertConnReceived(t, disconnected, outboundConn.ID(), ConnTypeOutbound) + select { + case <-hitMaxFailedAttempts: + time.Sleep(connTestReceiveTimeout) + case <-time.After(maxFailedAttempts * connTestReceiveTimeout): + t.Fatal("did not reach max failed attempts before timeout") + } + + // Ensure a new non-whitelisted manual connection to the same host now + // succeeds. + go cmgr.Connect(ctx, nextSameHost()) + assertConnReceived(t, connected, 0, ConnTypeManual) + assertConnManagerInternalState(t, cmgr) + + // Unpause the target outbound dials and ensure no additional automatic + // outbound connections to the same host are made despite being under the + // target outbound. + noConnWaitTimeout := connTestReceiveTimeout + cmgr.cfg.RetryDuration + pauseTargetOutbound.Store(false) + assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) + assertConnManagerInternalState(t, cmgr) + + // Ensure persistent connections are also subject to the max per-host + // connections by adding one and confirming it is NOT established. + _, err = cmgr.AddPersistent(nextSameHost()) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } + assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) + assertConnManagerInternalState(t, cmgr) + + // Ensure clean shutdown of connection manager. + shutdown() + wg.Wait() + assertConnManagerCleanShutdown(t, cmgr) +} From a09f8f57c618985c175e50d8544e9edfb8ae4a3b Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 21 May 2026 17:20:32 -0500 Subject: [PATCH 24/51] connmgr: Update README.md for per-host conn limits. --- internal/connmgr/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 850f2620e..30ebaef01 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -33,6 +33,8 @@ The following is a brief overview of the key features: - Supports manual connection establishment via `Connect` - Connection limits - Limits total normal (non-persistent) connections to `MaxNormalConns` + - Limits per-host connections to `MaxConnsPerHost` with exemptions for + whitelisted and loopback addresses - Duplicate address prevention - Rejects duplicate connections to and from the same address (host:port) - Whitelist support From 4d627a48907b84e572f37d973c90e88ee11f2d05 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sat, 23 May 2026 14:43:42 -0500 Subject: [PATCH 25/51] connmgr: Separate mock conn and addr test code. This moves the code related to mock addresses, connections, and listeners to a separate file. This helps keep it from cluttering up the main test code and also makes it easier to reuse in other packages. --- internal/connmgr/connmanager_test.go | 107 ------------------------ internal/connmgr/mockconn_test.go | 120 +++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 107 deletions(-) create mode 100644 internal/connmgr/mockconn_test.go diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index c1b7965d1..33ef22004 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -9,7 +9,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/netip" "reflect" @@ -59,56 +58,6 @@ func runConnMgrAsync(ctx context.Context, cmgr *ConnManager) (context.Context, c return ctx, cancel, &wg } -// mockAddr mocks a network address. -type mockAddr struct { - net, address string -} - -func (m mockAddr) Network() string { return m.net } -func (m mockAddr) String() string { return m.address } - -// mockConn mocks a network connection by implementing the net.Conn interface. -type mockConn struct { - io.Reader - io.Writer - io.Closer - - // local network, address for the connection. - lnet, laddr string - - // remote network, address for the connection. - rAddr net.Addr -} - -// LocalAddr returns the local address for the connection. -func (c mockConn) LocalAddr() net.Addr { - return &mockAddr{c.lnet, c.laddr} -} - -// RemoteAddr returns the remote address for the connection. -func (c mockConn) RemoteAddr() net.Addr { - return &mockAddr{c.rAddr.Network(), c.rAddr.String()} -} - -// Close handles closing the connection. -func (c mockConn) Close() error { - return nil -} - -func (c mockConn) SetDeadline(t time.Time) error { return nil } -func (c mockConn) SetReadDeadline(t time.Time) error { return nil } -func (c mockConn) SetWriteDeadline(t time.Time) error { return nil } - -// mockDialer mocks the net.Dial interface by returning a mock connection to -// the given address. -func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) { - r, w := io.Pipe() - c := &mockConn{rAddr: &mockAddr{network, addr}} - c.Reader = r - c.Writer = w - return c, ctx.Err() -} - // newTestConnManager returns a new connection manager with the provided // configuration and some timeout tweaks so that it is suitable for use in the // tests. @@ -1488,62 +1437,6 @@ func TestConnectContext(t *testing.T) { assertConnManagerCleanShutdown(t, cmgr) } -// mockListener implements the net.Listener interface and is used to test -// code that deals with net.Listeners without having to actually make any real -// connections. -type mockListener struct { - localAddr string - provideConn chan net.Conn -} - -// Accept returns a mock connection when it receives a signal via the Connect -// function. -// -// This is part of the net.Listener interface. -func (m *mockListener) Accept() (net.Conn, error) { - for conn := range m.provideConn { - return conn, nil - } - return nil, errors.New("network connection closed") -} - -// Close closes the mock listener which will cause any blocked Accept -// operations to be unblocked and return errors. -// -// This is part of the net.Listener interface. -func (m *mockListener) Close() error { - close(m.provideConn) - return nil -} - -// Addr returns the address the mock listener was configured with. -// -// This is part of the net.Listener interface. -func (m *mockListener) Addr() net.Addr { - return &mockAddr{"tcp", m.localAddr} -} - -// Connect fakes a connection to the mock listener from the provided remote -// address. It will cause the Accept function to return a mock connection -// configured with the provided remote address and the local address for the -// mock listener. -func (m *mockListener) Connect(addr net.Addr) { - m.provideConn <- &mockConn{ - laddr: m.localAddr, - lnet: "tcp", - rAddr: addr, - } -} - -// newMockListener returns a new mock listener for the provided local address -// and port. No ports are actually opened. -func newMockListener(localAddr string) *mockListener { - return &mockListener{ - localAddr: localAddr, - provideConn: make(chan net.Conn), - } -} - // TestListeners ensures providing listeners to the connection manager along // with an accept callback works properly. func TestListeners(t *testing.T) { diff --git a/internal/connmgr/mockconn_test.go b/internal/connmgr/mockconn_test.go new file mode 100644 index 000000000..36150e5b8 --- /dev/null +++ b/internal/connmgr/mockconn_test.go @@ -0,0 +1,120 @@ +// Copyright (c) 2016 The btcsuite developers +// Copyright (c) 2019-2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "context" + "errors" + "io" + "net" + "time" +) + +// mockAddr mocks a network address. +type mockAddr struct { + net, address string +} + +func (m mockAddr) Network() string { return m.net } +func (m mockAddr) String() string { return m.address } + +// mockConn mocks a network connection by implementing the net.Conn interface. +type mockConn struct { + io.Reader + io.Writer + io.Closer + + // local network, address for the connection. + lnet, laddr string + + // remote network, address for the connection. + rAddr net.Addr +} + +// LocalAddr returns the local address for the connection. +func (c mockConn) LocalAddr() net.Addr { + return &mockAddr{c.lnet, c.laddr} +} + +// RemoteAddr returns the remote address for the connection. +func (c mockConn) RemoteAddr() net.Addr { + return &mockAddr{c.rAddr.Network(), c.rAddr.String()} +} + +// Close handles closing the connection. +func (c mockConn) Close() error { + return nil +} + +func (c mockConn) SetDeadline(t time.Time) error { return nil } +func (c mockConn) SetReadDeadline(t time.Time) error { return nil } +func (c mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// mockDialer mocks the net.Dial interface by returning a mock connection to +// the given address. +func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) { + r, w := io.Pipe() + c := &mockConn{rAddr: &mockAddr{network, addr}} + c.Reader = r + c.Writer = w + return c, ctx.Err() +} + +// mockListener implements the net.Listener interface and is used to test +// code that deals with net.Listeners without having to actually make any real +// connections. +type mockListener struct { + localAddr string + provideConn chan net.Conn +} + +// Accept returns a mock connection when it receives a signal via the Connect +// function. +// +// This is part of the net.Listener interface. +func (m *mockListener) Accept() (net.Conn, error) { + for conn := range m.provideConn { + return conn, nil + } + return nil, errors.New("network connection closed") +} + +// Close closes the mock listener which will cause any blocked Accept +// operations to be unblocked and return errors. +// +// This is part of the net.Listener interface. +func (m *mockListener) Close() error { + close(m.provideConn) + return nil +} + +// Addr returns the address the mock listener was configured with. +// +// This is part of the net.Listener interface. +func (m *mockListener) Addr() net.Addr { + return &mockAddr{"tcp", m.localAddr} +} + +// Connect fakes a connection to the mock listener from the provided remote +// address. It will cause the Accept function to return a mock connection +// configured with the provided remote address and the local address for the +// mock listener. +func (m *mockListener) Connect(addr net.Addr) { + m.provideConn <- &mockConn{ + laddr: m.localAddr, + lnet: "tcp", + rAddr: addr, + } +} + +// newMockListener returns a new mock listener for the provided local address +// and port. No ports are actually opened. +func newMockListener(localAddr string) *mockListener { + return &mockListener{ + localAddr: localAddr, + provideConn: make(chan net.Conn), + } +} From b2b20c9149630a8cfe11a68e0820bbacf5314bec Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sat, 23 May 2026 14:55:36 -0500 Subject: [PATCH 26/51] connmgr: Use more modern t.Cleanup in tests. Most of the tests in this package were written well before some of the more modern test conveniences like t.Cleanup were added. As a result, almost all of the tests repeat the code related to waiting for and asserting clean shutdown. This updates the tests so that the main method used to run the connection manager in all of the tests now registers a cleanup func via t.Cleanup to wait for clean shutdown and assert the internal state is clean as expected. This approach is much less error prone and convenient. --- internal/connmgr/connmanager_test.go | 161 +++++++-------------------- 1 file changed, 43 insertions(+), 118 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 33ef22004..50d082674 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -44,17 +44,31 @@ func mustParseAddrPort(addr string) *net.TCPAddr { } } -// runConnMgrAsync invokes the Run method on the passed connection manager in a -// separate goroutine and returns a cancelable context and wait group the caller -// can use to shutdown the connection manager and wait for clean shutdown. -func runConnMgrAsync(ctx context.Context, cmgr *ConnManager) (context.Context, context.CancelFunc, *sync.WaitGroup) { +// runConnMgrAsync invokes [ConnManager.Run] on the passed connection manager in +// a separate goroutine and returns a cancelable context and wait group the +// caller can use to shutdown the connection manager and wait for clean +// shutdown. +// +// It also registers a test cleanup func that waits for shutdown and asserts the +// internal state of the connection manager is empty as expected. +func runConnMgrAsync(t *testing.T, ctx context.Context, cm *ConnManager) (context.Context, context.CancelFunc, *sync.WaitGroup) { + t.Helper() + ctx, cancel := context.WithCancel(ctx) var wg sync.WaitGroup wg.Add(1) go func() { - cmgr.Run(ctx) + cm.Run(ctx) wg.Done() }() + t.Cleanup(func() { + t.Helper() + + cancel() + wg.Wait() + assertConnManagerCleanShutdown(t, cm) + }) + return ctx, cancel, &wg } @@ -93,8 +107,11 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { // Assert the pending and active maps are mutually exclusive for both conn // IDs and addrs. // - // Also build a map of addrs to conn IDs and tally the per host counts in - // the pending, active, and persistent maps for the checks below. + // Also build maps of the following data in the pending, active, and + // persistent maps: + // + // - addrs to conn IDs + // - the per host counts connIDByAddr := make(map[string]uint64) perHostCounts := make(map[string]uint32) for id, info := range cm.pending { @@ -396,7 +413,7 @@ func TestConnectMode(t *testing.T) { connected <- conn }, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) addr := mustParseAddrPort("127.0.0.1:18555") go cmgr.Connect(ctx, addr) @@ -405,11 +422,6 @@ func TestConnectMode(t *testing.T) { assertConnReceived(t, connected, 0, ConnTypeManual) assertNoConnReceived(t, connected) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects @@ -451,7 +463,7 @@ func TestDisconnect(t *testing.T) { disconnected <- conn }, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Attempt a connection to a localhost IP. notifyDialed.Store(true) @@ -567,11 +579,6 @@ func TestDisconnect(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestRemove ensures that [ConnManager.Remove] properly removes pending and @@ -614,7 +621,7 @@ func TestRemove(t *testing.T) { disconnected <- conn }, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Ensure removing an ID that doesn't exist returns the expected error. if err := cmgr.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { @@ -744,11 +751,6 @@ func TestRemove(t *testing.T) { } assertConnReceived(t, disconnected, connID, ConnTypeManual) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestTargetOutbound tests the target number of outbound connections @@ -771,7 +773,7 @@ func TestTargetOutbound(t *testing.T) { connected <- conn }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) // Ensure only the expected number of target outbound conns are established // and no more. @@ -780,11 +782,6 @@ func TestTargetOutbound(t *testing.T) { } assertNoConnReceived(t, connected) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestDoubleClose ensures closing a connection multiple times is a noop after @@ -803,7 +800,7 @@ func TestDoubleClose(t *testing.T) { connected <- conn }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) // Wait for the connection to be established. conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) @@ -825,11 +822,6 @@ func TestDoubleClose(t *testing.T) { t.Fatal("connection closed more than once") } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestRetryPersistent tests that persistent connections are retried. @@ -849,7 +841,7 @@ func TestRetryPersistent(t *testing.T) { disconnected <- conn }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) addr := mustParseAddrPort("127.0.0.1:18555") connID, err := cmgr.AddPersistent(addr) @@ -876,11 +868,6 @@ func TestRetryPersistent(t *testing.T) { assertConnReceived(t, disconnected, connID, ConnTypeManual) assertRemovedPersistent(t, cmgr, addr) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum @@ -900,7 +887,7 @@ func TestMaxPersistent(t *testing.T) { disconnected <- conn }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) var numAddrs uint32 nextAddr := func() net.Addr { @@ -962,11 +949,6 @@ func TestMaxPersistent(t *testing.T) { t.Fatalf("failed to add persistent connection %v: %v", addr, err) } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestMaxRetryDuration tests the maximum retry duration. @@ -1002,7 +984,7 @@ func TestMaxRetryDuration(t *testing.T) { connected <- conn }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) if err != nil { @@ -1019,11 +1001,6 @@ func TestMaxRetryDuration(t *testing.T) { const timeout = connTestReceiveTimeout + networkUpTimeout assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestNetworkFailure tests that the connection manager handles a network @@ -1059,7 +1036,7 @@ func TestNetworkFailure(t *testing.T) { conn.RemoteAddr()) }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(t, context.Background(), cmgr) // Shutdown the connection manager after the max failed attempts is reached // and an additional retry duration has passed and then wait for the @@ -1083,8 +1060,6 @@ func TestNetworkFailure(t *testing.T) { t.Fatalf("unexpected number of dials - got %v, want <= %v", gotDials, wantMaxDials) } - - assertConnManagerCleanShutdown(t, cmgr) } // TestMultipleFailedConns ensures that the connection manager remains @@ -1113,7 +1088,7 @@ func TestMultipleFailedConns(t *testing.T) { Dial: errDialer, }) cmgr.maxRetryDuration = maxRetryDuration - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) // Establish several connection requests to localhost IPs. for i := range targetFailed { @@ -1148,11 +1123,6 @@ func TestMultipleFailedConns(t *testing.T) { t.Fatal("timeout servicing connmgr requests") } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestShutdownFailedConns tests that failed connections are ignored after @@ -1172,7 +1142,7 @@ func TestShutdownFailedConns(t *testing.T) { Dial: waitDialer, }) cmgr.maxRetryDuration = retryTimeout - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) // Add a persistent connection. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1190,11 +1160,6 @@ func TestShutdownFailedConns(t *testing.T) { t.Fatal("timeout waiting for dial") } time.Sleep(connTestNonReceiveTimeout) - shutdown() - - // Ensure clean shutdown of connection manager. - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestRemovePendingConnection ensures that removing a pending outbound @@ -1215,7 +1180,7 @@ func TestRemovePendingConnection(t *testing.T) { cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Establish a connection request to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1250,11 +1215,6 @@ func TestRemovePendingConnection(t *testing.T) { t.Fatalf("connection %s is still pending", addr) } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestCancelIgnoreDelayedConnection tests that a canceled pending persistent @@ -1293,7 +1253,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { connected <- conn }, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) // Establish a persistent connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1325,11 +1285,6 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // properly elapse. assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestDialTimeout ensure [Config.Timeout] works as intended by creating a @@ -1356,7 +1311,7 @@ func TestDialTimeout(t *testing.T) { Dial: timeoutDialer, DialTimeout: dialTimeout, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Establish a connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1371,11 +1326,6 @@ func TestDialTimeout(t *testing.T) { t.Fatal("timeout waiting for dial cancellation") } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestConnectContext ensures the [ConnManager.Connect] method works as intended @@ -1394,7 +1344,7 @@ func TestConnectContext(t *testing.T) { cmgr := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Establish a connection request to a localhost IP with a separate context // that can be canceled. @@ -1430,11 +1380,6 @@ func TestConnectContext(t *testing.T) { t.Fatal("timeout waiting for dial cancellation") } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestListeners ensures providing listeners to the connection manager along @@ -1455,7 +1400,7 @@ func TestListeners(t *testing.T) { }, Dial: mockDialer, }) - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cmgr) // Fake a couple of mock connections to each of the listeners. go func() { @@ -1472,11 +1417,6 @@ func TestListeners(t *testing.T) { assertConnReceived(t, receivedConns, 0, ConnTypeInbound) } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestRejectDuplicateConns ensures duplicate addresses are rejected. This @@ -1514,7 +1454,7 @@ func TestRejectDuplicateConns(t *testing.T) { disconnected <- conn }, }) - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Dial a manual connection and wait for it to become pending. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1616,11 +1556,6 @@ func TestRejectDuplicateConns(t *testing.T) { err) } assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestMaxNormalConns ensures the connection manager limits the total number of @@ -1680,7 +1615,7 @@ func TestMaxNormalConns(t *testing.T) { }, }) cmgr.maxRetryDuration = cmgr.cfg.RetryDuration - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Wait for the expected number of target outbound conns to be established. outbounds := make([]*Conn, 0, targetOutbound) @@ -1770,11 +1705,6 @@ func TestMaxNormalConns(t *testing.T) { } assertConnReceived(t, connected, connID, ConnTypeManual) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } // TestMaxConnsPerHost ensures the connection manager limits the total number of @@ -1837,7 +1767,7 @@ func TestMaxConnsPerHost(t *testing.T) { }, }) cmgr.maxRetryDuration = cmgr.cfg.RetryDuration - ctx, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) // Wait for the maximum allowed non-whitelisted per-host automatic outbound // conns. @@ -1922,9 +1852,4 @@ func TestMaxConnsPerHost(t *testing.T) { } assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) assertConnManagerInternalState(t, cmgr) - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() - assertConnManagerCleanShutdown(t, cmgr) } From f7765383886097bdc226c70d1e291d97c59a47ef Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sat, 23 May 2026 15:02:59 -0500 Subject: [PATCH 27/51] connmgr: Consistent naming in tests. This makes the name of the connection manager instances in the test code match the naming used in the implmentation code. --- internal/connmgr/connmanager_test.go | 396 +++++++++++++-------------- 1 file changed, 198 insertions(+), 198 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 50d082674..0f46af1ec 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -78,12 +78,12 @@ func runConnMgrAsync(t *testing.T, ctx context.Context, cm *ConnManager) (contex func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { t.Helper() - cmgr, err := New(cfg) + cm, err := New(cfg) if err != nil { t.Fatalf("New: unexpected error: %v", err) } - cmgr.maxRetryDuration = defaultTestMaxRetryDuration - return cmgr + cm.maxRetryDuration = defaultTestMaxRetryDuration + return cm } // assertConnManagerInternalState ensures the internal state of the passed @@ -270,7 +270,7 @@ func TestIsWhitelisted(t *testing.T) { for _, test := range tests { // Parse the whitelist entries for the test. - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: mockDialer, Whitelists: test.prefixes, }) @@ -282,7 +282,7 @@ func TestIsWhitelisted(t *testing.T) { t.Fatalf("%q-%q: failed to parse address: %v", test.name, pmTest.addr, err) } - if got := cmgr.IsWhitelisted(addr); got != pmTest.whitelisted { + if got := cm.IsWhitelisted(addr); got != pmTest.whitelisted { t.Errorf("%q-%q: mismatched result -- got %v, want %v", test.name, pmTest.addr, got, pmTest.whitelisted) continue @@ -406,22 +406,22 @@ func TestConnectMode(t *testing.T) { t.Parallel() connected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ TargetOutbound: 2, Dial: mockDialer, OnConnection: func(conn *Conn) { connected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) + go cm.Connect(ctx, addr) // Ensure that only a single connection is received. assertConnReceived(t, connected, 0, ConnTypeManual) assertNoConnReceived(t, connected) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestDisconnect ensures that [ConnManager.Disconnect] properly disconnects @@ -454,7 +454,7 @@ func TestDisconnect(t *testing.T) { } return conn, err } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: pendingDialer, OnConnection: func(conn *Conn) { connected <- conn @@ -463,14 +463,14 @@ func TestDisconnect(t *testing.T) { disconnected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Attempt a connection to a localhost IP. notifyDialed.Store(true) waitForPending.Store(true) notifyCanceled.Store(true) addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) + go cm.Connect(ctx, addr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -479,15 +479,15 @@ func TestDisconnect(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for dial") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Disconnect the connection attempt while it's still pending. - connID, _ := pendingAddrConnID(cmgr, addr) - if err := cmgr.Disconnect(connID); err != nil { + connID, _ := pendingAddrConnID(cm, addr) + if err := cm.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Allow the dialer to proceed with the disconnected connection attempt and // then wait for the dialer to signal the context associated with the dial @@ -502,37 +502,37 @@ func TestDisconnect(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for cancel") } - if _, ok := pendingAddrConnID(cmgr, addr); ok { + if _, ok := pendingAddrConnID(cm, addr); ok { t.Fatalf("connection %s is still pending", addr) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Start a connection attempt and wait for it to be established. notifyDialed.Store(false) waitForPending.Store(false) notifyCanceled.Store(false) - go cmgr.Connect(ctx, addr) + go cm.Connect(ctx, addr) conn := assertConnReceived(t, connected, 0, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Disconnect the established connection and wait for the disconnect // notification to ensure it is disconnected as intended. connID = conn.ID() - if err := cmgr.Disconnect(connID); err != nil { + if err := cm.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Add a persistent connection back to the same address. notifyDialed.Store(true) waitForPending.Store(true) notifyCanceled.Store(true) - connID, err := cmgr.AddPersistent(addr) + connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -541,14 +541,14 @@ func TestDisconnect(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for dial") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Disconnect the persistent connection attempt while it's still pending. - if err := cmgr.Disconnect(connID); err != nil { + if err := cm.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Allow the dialer to proceed with the disconnected persistent connection // attempt and then wait for the dialer to signal the context associated @@ -570,15 +570,15 @@ func TestDisconnect(t *testing.T) { // Wait for the retry to be established. assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Disconnect the established persistent connection and wait for the // disconnect notification to ensure it is disconnected as intended. - if err := cmgr.Disconnect(connID); err != nil { + if err := cm.Disconnect(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestRemove ensures that [ConnManager.Remove] properly removes pending and @@ -612,7 +612,7 @@ func TestRemove(t *testing.T) { } return conn, err } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: pendingDialer, OnConnection: func(conn *Conn) { connected <- conn @@ -621,10 +621,10 @@ func TestRemove(t *testing.T) { disconnected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Ensure removing an ID that doesn't exist returns the expected error. - if err := cmgr.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { + if err := cm.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { t.Fatalf("mismatched remove error: got %v, want %v", err, ErrNotFound) } @@ -633,7 +633,7 @@ func TestRemove(t *testing.T) { waitForPending.Store(true) notifyCanceled.Store(true) addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) + go cm.Connect(ctx, addr) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -642,12 +642,12 @@ func TestRemove(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for dial") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Remove the connection attempt while it's still pending. - connID, _ := pendingAddrConnID(cmgr, addr) - if err := cmgr.Remove(connID); err != nil { + connID, _ := pendingAddrConnID(cm, addr) + if err := cm.Remove(connID); err != nil { t.Fatalf("unexpected remove err: %v", err) } @@ -664,37 +664,37 @@ func TestRemove(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for cancel") } - if _, ok := pendingAddrConnID(cmgr, addr); ok { + if _, ok := pendingAddrConnID(cm, addr); ok { t.Fatalf("connection %s is still pending", addr) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Start a connection attempt and wait for it to be established. notifyDialed.Store(false) waitForPending.Store(false) notifyCanceled.Store(false) - go cmgr.Connect(ctx, addr) + go cm.Connect(ctx, addr) conn := assertConnReceived(t, connected, 0, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Remove the established connection and wait for the disconnect // notification to ensure it is disconnected as intended. connID = conn.ID() - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Add a persistent connection back to the same address. notifyDialed.Store(true) waitForPending.Store(true) notifyCanceled.Store(true) - connID, err := cmgr.AddPersistent(addr) + connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -703,13 +703,13 @@ func TestRemove(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for dial") } - assertPendingAddr(t, cmgr, addr) + assertPendingAddr(t, cm, addr) // Remove the persistent connection attempt while it's still pending. - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Allow the dialer to proceed with the removed persistent connection // attempt and then wait for the dialer to signal the context associated @@ -728,29 +728,29 @@ func TestRemove(t *testing.T) { case <-time.After(time.Millisecond * 5): t.Fatal("timeout waiting for cancel") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Add a persistent connection back to the same address and wait for it to // be established. notifyDialed.Store(false) waitForPending.Store(false) notifyCanceled.Store(false) - connID, err = cmgr.AddPersistent(addr) + connID, err = cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } conn2 := assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Remove the established persistent connection and wait for the disconnect // notification to ensure it is disconnected as intended. Also, ensure the // persistent connection entry is removed. connID = conn2.ID() - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("unexpected disconnect err: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestTargetOutbound tests the target number of outbound connections @@ -762,7 +762,7 @@ func TestTargetOutbound(t *testing.T) { const targetOutbound = 10 var nextAddr atomic.Uint32 connected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { @@ -773,7 +773,7 @@ func TestTargetOutbound(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) // Ensure only the expected number of target outbound conns are established // and no more. @@ -781,7 +781,7 @@ func TestTargetOutbound(t *testing.T) { assertConnReceived(t, connected, 0, ConnTypeOutbound) } assertNoConnReceived(t, connected) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestDoubleClose ensures closing a connection multiple times is a noop after @@ -790,7 +790,7 @@ func TestDoubleClose(t *testing.T) { t.Parallel() connected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ TargetOutbound: 1, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { @@ -800,11 +800,11 @@ func TestDoubleClose(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) // Wait for the connection to be established. conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Override the close func to cleanly detect closes. var numClosed uint32 @@ -821,7 +821,7 @@ func TestDoubleClose(t *testing.T) { if numClosed != 1 { t.Fatal("connection closed more than once") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestRetryPersistent tests that persistent connections are retried. @@ -830,7 +830,7 @@ func TestRetryPersistent(t *testing.T) { connected := make(chan *Conn) disconnected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: mockDialer, @@ -841,14 +841,14 @@ func TestRetryPersistent(t *testing.T) { disconnected <- conn }, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) addr := mustParseAddrPort("127.0.0.1:18555") - connID, err := cmgr.AddPersistent(addr) + connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } - if !cmgr.IsPersistent(connID) { + if !cm.IsPersistent(connID) { t.Fatal("IsPersistent did not reported true for persistent conn") } @@ -858,16 +858,16 @@ func TestRetryPersistent(t *testing.T) { conn.Close() assertConnReceived(t, disconnected, connID, ConnTypeManual) assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Remove the persistent connection, wait for it to disconnect, and ensure // it is actually removed. - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("failed to remove persistent connection: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertRemovedPersistent(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertRemovedPersistent(t, cm, addr) + assertConnManagerInternalState(t, cm) } // TestMaxPersistent ensures [ConnManager.AddPersistent] limits the maximum @@ -878,7 +878,7 @@ func TestMaxPersistent(t *testing.T) { connected := make(chan *Conn) disconnected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: mockDialer, OnConnection: func(conn *Conn) { connected <- conn @@ -887,7 +887,7 @@ func TestMaxPersistent(t *testing.T) { disconnected <- conn }, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) var numAddrs uint32 nextAddr := func() net.Addr { @@ -901,7 +901,7 @@ func TestMaxPersistent(t *testing.T) { addrs := make([]net.Addr, 0, MaxPersistent) for range MaxPersistent { addr := nextAddr() - connID, err := cmgr.AddPersistent(addr) + connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection %v: %v", addr, err) } @@ -910,45 +910,45 @@ func TestMaxPersistent(t *testing.T) { // Wait for the connection. assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // Attempting to add more than the max allowed number of persistent conns // should be rejected. - _, err := cmgr.AddPersistent(nextAddr()) + _, err := cm.AddPersistent(nextAddr()) if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject > max persistent, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure disconnecting the persistent conn does not incorrectly decrement // the count. connID, addr := connIDs[0], addrs[0] - if err := cmgr.Disconnect(connID); err != nil { + if err := cm.Disconnect(connID); err != nil { t.Fatalf("failed to disconnect persistent conn %v: %v", addr, err) } - _, err = cmgr.AddPersistent(nextAddr()) + _, err = cm.AddPersistent(nextAddr()) if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject max persistent after dc, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Remove the first persistent connection, wait for it to disconnect, and // ensure it is actually removed. - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("failed to remove persistent conn %v: %v", addr, err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertRemovedPersistent(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertRemovedPersistent(t, cm, addr) + assertConnManagerInternalState(t, cm) // A new persistent conn should now be allowed. addr = nextAddr() - _, err = cmgr.AddPersistent(addr) + _, err = cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection %v: %v", addr, err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestMaxRetryDuration tests the maximum retry duration. @@ -976,7 +976,7 @@ func TestMaxRetryDuration(t *testing.T) { } connected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ RetryDuration: time.Millisecond, TargetOutbound: 1, Dial: timedDialer, @@ -984,9 +984,9 @@ func TestMaxRetryDuration(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) - connID, err := cmgr.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) + connID, err := cm.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } @@ -1000,7 +1000,7 @@ func TestMaxRetryDuration(t *testing.T) { }) const timeout = connTestReceiveTimeout + networkUpTimeout assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestNetworkFailure tests that the connection manager handles a network @@ -1023,7 +1023,7 @@ func TestNetworkFailure(t *testing.T) { return nil, errors.New("network down") } var nextAddr atomic.Uint32 - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, @@ -1036,7 +1036,7 @@ func TestNetworkFailure(t *testing.T) { conn.RemoteAddr()) }, }) - _, shutdown, wg := runConnMgrAsync(t, context.Background(), cmgr) + _, shutdown, wg := runConnMgrAsync(t, context.Background(), cm) // Shutdown the connection manager after the max failed attempts is reached // and an additional retry duration has passed and then wait for the @@ -1083,22 +1083,22 @@ func TestMultipleFailedConns(t *testing.T) { } return nil, errors.New("network down") } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ RetryDuration: maxRetryDuration, Dial: errDialer, }) - cmgr.maxRetryDuration = maxRetryDuration - runConnMgrAsync(t, context.Background(), cmgr) + cm.maxRetryDuration = maxRetryDuration + runConnMgrAsync(t, context.Background(), cm) // Establish several connection requests to localhost IPs. for i := range targetFailed { addr := mustParseAddrPort(fmt.Sprintf("127.0.0.%d:18555", i+1)) - _, err := cmgr.AddPersistent(addr) + _, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("unexpected add err: %v", err) } } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Wait for the target number of dials and ensure they happen simultaneously // by checking it happens before the retry timeout. @@ -1107,14 +1107,14 @@ func TestMultipleFailedConns(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatal("did not reach target number of dials before timeout") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure that the connection manager still responds to requests while the // failed connections are still retrying. disconnected := make(chan struct{}) go func() { const badID = ^uint64(0) - cmgr.Disconnect(badID) + cm.Disconnect(badID) close(disconnected) }() select { @@ -1122,7 +1122,7 @@ func TestMultipleFailedConns(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatal("timeout servicing connmgr requests") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestShutdownFailedConns tests that failed connections are ignored after @@ -1137,20 +1137,20 @@ func TestShutdownFailedConns(t *testing.T) { closeOnce.Do(func() { close(dialed) }) return nil, errors.New("network down") } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ RetryDuration: retryTimeout, Dial: waitDialer, }) - cmgr.maxRetryDuration = retryTimeout - runConnMgrAsync(t, context.Background(), cmgr) + cm.maxRetryDuration = retryTimeout + runConnMgrAsync(t, context.Background(), cm) // Add a persistent connection. addr := mustParseAddrPort("127.0.0.1:18555") - _, err := cmgr.AddPersistent(addr) + _, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("unexpected error: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Shutdown the connection manager during the retry timeout after a failed // dial attempt. @@ -1177,15 +1177,15 @@ func TestRemovePendingConnection(t *testing.T) { close(canceled) return nil, errors.New("error") } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Establish a connection request to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) - assertConnManagerInternalState(t, cmgr) + go cm.Connect(ctx, addr) + assertConnManagerInternalState(t, cm) // Wait for the connection manager to attempt to dial and ensure the // connection is marked as pending while the dialer is blocked. @@ -1194,15 +1194,15 @@ func TestRemovePendingConnection(t *testing.T) { case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Cancel the connection attempt while it's still pending. - connID, _ := pendingAddrConnID(cmgr, addr) - if err := cmgr.Remove(connID); err != nil { + connID, _ := pendingAddrConnID(cm, addr) + if err := cm.Remove(connID); err != nil { t.Fatalf("unexpected remove err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Wait for the dialer to signal the context associated with the dial was // canceled and ensure the internal pending state is removed. @@ -1211,10 +1211,10 @@ func TestRemovePendingConnection(t *testing.T) { case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for cancel") } - if _, ok := pendingAddrConnID(cmgr, addr); ok { + if _, ok := pendingAddrConnID(cm, addr); ok { t.Fatalf("connection %s is still pending", addr) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestCancelIgnoreDelayedConnection tests that a canceled pending persistent @@ -1246,22 +1246,22 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { } connected := make(chan *Conn) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: failingDialer, RetryDuration: retryTimeout, OnConnection: func(conn *Conn) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) // Establish a persistent connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") - connID, err := cmgr.AddPersistent(addr) + connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("unexpected error: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Wait for the retry and ensure the connection is pending. select { @@ -1269,12 +1269,12 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Fatalf("did not get retry before timeout") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Remove the connection and then immediately allow the next connection to // succeed. - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("unexpected remove err: %v", err) } close(connect) @@ -1284,7 +1284,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // timeout window to ensure the connection manager's backoff is allowed to // properly elapse. assertNoConnReceivedTimeout(t, connected, 5*retryTimeout) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestDialTimeout ensure [Config.Timeout] works as intended by creating a @@ -1307,16 +1307,16 @@ func TestDialTimeout(t *testing.T) { return mockDialer(ctx, network, addr) } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: timeoutDialer, DialTimeout: dialTimeout, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Establish a connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) - assertConnManagerInternalState(t, cmgr) + go cm.Connect(ctx, addr) + assertConnManagerInternalState(t, cm) // Wait to receive the signal that the dialer context was cancelled, which // means the dial timeout was hit. @@ -1325,7 +1325,7 @@ func TestDialTimeout(t *testing.T) { case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestConnectContext ensures the [ConnManager.Connect] method works as intended @@ -1341,10 +1341,10 @@ func TestConnectContext(t *testing.T) { <-ctx.Done() return nil, ctx.Err() } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Establish a connection request to a localhost IP with a separate context // that can be canceled. @@ -1352,7 +1352,7 @@ func TestConnectContext(t *testing.T) { connectCtx, cancelConnect := context.WithCancel(ctx) connectErr := make(chan error, 1) go func() { - _, err := cmgr.Connect(connectCtx, addr) + _, err := cm.Connect(connectCtx, addr) connectErr <- err }() @@ -1364,8 +1364,8 @@ func TestConnectContext(t *testing.T) { case <-time.After(time.Millisecond * 20): t.Fatal("timeout waiting for dial") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Cancel the connection context, wait for the error from connect, and // ensure it is the expected error. @@ -1379,7 +1379,7 @@ func TestConnectContext(t *testing.T) { case <-time.After(10 * time.Millisecond): t.Fatal("timeout waiting for dial cancellation") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestListeners ensures providing listeners to the connection manager along @@ -1393,14 +1393,14 @@ func TestListeners(t *testing.T) { listener1 := newMockListener("127.0.0.1:9108") listener2 := newMockListener("127.0.0.1:9208") listeners := []net.Listener{listener1, listener2} - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Listeners: listeners, OnAccept: func(conn *Conn) { receivedConns <- conn }, Dial: mockDialer, }) - runConnMgrAsync(t, context.Background(), cmgr) + runConnMgrAsync(t, context.Background(), cm) // Fake a couple of mock connections to each of the listeners. go func() { @@ -1416,7 +1416,7 @@ func TestListeners(t *testing.T) { for range expectedNumConns { assertConnReceived(t, receivedConns, 0, ConnTypeInbound) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestRejectDuplicateConns ensures duplicate addresses are rejected. This @@ -1441,7 +1441,7 @@ func TestRejectDuplicateConns(t *testing.T) { <-pending return mockDialer(ctx, network, addr) } - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Listeners: []net.Listener{listener}, OnAccept: func(conn *Conn) { inboundConns <- conn @@ -1454,108 +1454,108 @@ func TestRejectDuplicateConns(t *testing.T) { disconnected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Dial a manual connection and wait for it to become pending. addr := mustParseAddrPort("127.0.0.1:18555") - go cmgr.Connect(ctx, addr) + go cm.Connect(ctx, addr) select { case <-dialed: case <-time.After(time.Millisecond * 5): t.Fatal("did not receive pending dial before timeout") } - assertPendingAddr(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertPendingAddr(t, cm, addr) + assertConnManagerInternalState(t, cm) // Duplicate connect to the pending address should be rejected. - if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { + if _, err := cm.Connect(ctx, addr); !errors.Is(err, ErrAlreadyPending) { t.Fatalf("did not reject duplicate pending connection, err: %v", err) } // Inbound attempts from the pending outbound address should be rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Allow the pending connection to complete. close(pending) conn := assertConnReceived(t, connected, 0, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Duplicate connect to the established address should be rejected. - if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + if _, err := cm.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject duplicate active connection, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Inbound attempts from the established outbound address should be // rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Close the connection and wait for the disconnect. conn.Close() assertConnReceived(t, disconnected, conn.ID(), ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Add a persistent connection back to the same address and wait for it to // connect since there are no longer any connections to the address. - connID, err := cmgr.AddPersistent(addr) + connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Duplicate persistent connection attempts should be rejected. - _, err = cmgr.AddPersistent(addr) + _, err = cm.AddPersistent(addr) if !errors.Is(err, ErrDuplicatePersistent) { t.Fatalf("did not reject duplicate persistent connection, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Manual connection attempts to persistent connection should be rejected. - _, err = cmgr.Connect(ctx, addr) + _, err = cm.Connect(ctx, addr) if !errors.Is(err, ErrDuplicatePersistent) { t.Fatalf("did not reject manual connection to persistent, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Inbound atempts from the persistent address should be rejected. go listener.Connect(addr) assertNoConnReceived(t, inboundConns) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Remove the persistent connection, wait for it to disconnect, and ensure // it is actually removed. - if err := cmgr.Remove(connID); err != nil { + if err := cm.Remove(connID); err != nil { t.Fatalf("failed to remove persistent connection: %v", err) } assertConnReceived(t, disconnected, connID, ConnTypeManual) - assertRemovedPersistent(t, cmgr, addr) - assertConnManagerInternalState(t, cmgr) + assertRemovedPersistent(t, cm, addr) + assertConnManagerInternalState(t, cm) // Inbound connections from the same address should now succeed. go listener.Connect(addr) assertConnReceived(t, inboundConns, 0, ConnTypeInbound) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Manual connection attempts to the inbound address should be rejected. - if _, err := cmgr.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { + if _, err := cm.Connect(ctx, addr); !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject outbound for existing inbound conn, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Attempts to add a persistent connection to an existing inbound should be // rejected. - _, err = cmgr.AddPersistent(addr) + _, err = cm.AddPersistent(addr) if !errors.Is(err, ErrAlreadyConnected) { t.Fatalf("did not reject persistent conn for existing inbound conn: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestMaxNormalConns ensures the connection manager limits the total number of @@ -1588,7 +1588,7 @@ func TestMaxNormalConns(t *testing.T) { var pauseTargetOutbound atomic.Bool var totalPausedAddrs atomic.Uint32 hitMaxFailedAttempts := make(chan struct{}) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Listeners: []net.Listener{listener}, MaxNormalConns: maxNormalConns, TargetOutbound: targetOutbound, @@ -1614,8 +1614,8 @@ func TestMaxNormalConns(t *testing.T) { disconnected <- conn }, }) - cmgr.maxRetryDuration = cmgr.cfg.RetryDuration - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + cm.maxRetryDuration = cm.cfg.RetryDuration + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Wait for the expected number of target outbound conns to be established. outbounds := make([]*Conn, 0, targetOutbound) @@ -1623,7 +1623,7 @@ func TestMaxNormalConns(t *testing.T) { conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) outbounds = append(outbounds, conn) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Establish target number of inbounds to the listener and wait for them to // be established. @@ -1637,13 +1637,13 @@ func TestMaxNormalConns(t *testing.T) { conn := assertConnReceived(t, inboundConns, 0, ConnTypeInbound) inbounds = append(inbounds, conn) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Establish target number of manual connections and wait for them to be // established. go func() { for range targetManual { - go cmgr.Connect(ctx, nextAddr()) + go cm.Connect(ctx, nextAddr()) } }() manualConns := make([]*Conn, 0, targetManual+1) @@ -1651,21 +1651,21 @@ func TestMaxNormalConns(t *testing.T) { conn := assertConnReceived(t, connected, 0, ConnTypeManual) manualConns = append(manualConns, conn) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure manual connections that would exceed the max allowed normal // connections are rejected. - _, err := cmgr.Connect(ctx, nextAddr()) + _, err := cm.Connect(ctx, nextAddr()) if !errors.Is(err, ErrMaxNormalConns) { t.Fatalf("did not reject manual connection at max allowed, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure inbound connections that would exceed the max allowed normal // connections are rejected. go listener.Connect(nextAddr()) assertNoConnReceived(t, inboundConns) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Pause the target outbound dials and remove one of the target outbound // connections to make room for another manual connection. Then wait for @@ -1680,31 +1680,31 @@ func TestMaxNormalConns(t *testing.T) { case <-time.After(maxFailedAttempts * connTestReceiveTimeout): t.Fatal("did not reach max failed attempts before timeout") } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Establish another manual connection to take the place of the target // outbound connection that was just closed and wait for it to be // established. - go cmgr.Connect(ctx, nextAddr()) + go cm.Connect(ctx, nextAddr()) assertConnReceived(t, connected, 0, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Unpause the target outbound dials and ensure no additional automatic // outbound connections are made despite being under the target outbound due // to max total conns. pauseTargetOutbound.Store(false) assertNoConnReceivedTimeout(t, connected, connTestNonReceiveTimeout+ - cmgr.cfg.RetryDuration) - assertConnManagerInternalState(t, cmgr) + cm.cfg.RetryDuration) + assertConnManagerInternalState(t, cm) // Ensure persistent connections are not subject to the max total normal // connections by adding one and waiting for it to be established. - connID, err := cmgr.AddPersistent(nextAddr()) + connID, err := cm.AddPersistent(nextAddr()) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } // TestMaxConnsPerHost ensures the connection manager limits the total number of @@ -1738,7 +1738,7 @@ func TestMaxConnsPerHost(t *testing.T) { var pauseTargetOutbound atomic.Bool var totalPausedAddrs atomic.Uint32 hitMaxFailedAttempts := make(chan struct{}) - cmgr := newTestConnManager(t, &Config{ + cm := newTestConnManager(t, &Config{ Listeners: []net.Listener{listener}, MaxNormalConns: 30, // High enough to not interfere with per-host tests. MaxConnsPerHost: maxConnsPerHost, @@ -1766,8 +1766,8 @@ func TestMaxConnsPerHost(t *testing.T) { disconnected <- conn }, }) - cmgr.maxRetryDuration = cmgr.cfg.RetryDuration - ctx, _, _ := runConnMgrAsync(t, context.Background(), cmgr) + cm.maxRetryDuration = cm.cfg.RetryDuration + ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) // Wait for the maximum allowed non-whitelisted per-host automatic outbound // conns. @@ -1776,27 +1776,27 @@ func TestMaxConnsPerHost(t *testing.T) { conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) outboundConns = append(outboundConns, conn) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure non-whitelisted manual connections that would exceed the max // allowed per-host connections are rejected. - _, err := cmgr.Connect(ctx, nextSameHost()) + _, err := cm.Connect(ctx, nextSameHost()) if !errors.Is(err, ErrMaxConnsPerHost) { t.Fatalf("did not reject manual connection at per-host limit, err: %v", err) } - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure non-whitelisted inbound connections that would exceed the max // allowed per-host connections are rejected. go listener.Connect(nextSameHost()) assertNoConnReceived(t, inboundConns) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure whitelisted manual connections are allowed to exceed the per-host // limit. for range maxConnsPerHost + 1 { - go cmgr.Connect(ctx, nextSameWhitelistedHost()) + go cm.Connect(ctx, nextSameWhitelistedHost()) assertConnReceived(t, connected, 0, ConnTypeManual) } @@ -1804,16 +1804,16 @@ func TestMaxConnsPerHost(t *testing.T) { // limit. go listener.Connect(nextSameWhitelistedHost()) assertConnReceived(t, inboundConns, 0, ConnTypeInbound) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure whitelisted persistent connections are allowed to exceed the // per-host limit. - connID, err := cmgr.AddPersistent(nextSameWhitelistedHost()) + connID, err := cm.AddPersistent(nextSameWhitelistedHost()) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } assertConnReceived(t, connected, connID, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Pause the target outbound dials and remove one of the target outbound // connections to make room for another manual connection with the same @@ -1832,24 +1832,24 @@ func TestMaxConnsPerHost(t *testing.T) { // Ensure a new non-whitelisted manual connection to the same host now // succeeds. - go cmgr.Connect(ctx, nextSameHost()) + go cm.Connect(ctx, nextSameHost()) assertConnReceived(t, connected, 0, ConnTypeManual) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Unpause the target outbound dials and ensure no additional automatic // outbound connections to the same host are made despite being under the // target outbound. - noConnWaitTimeout := connTestReceiveTimeout + cmgr.cfg.RetryDuration + noConnWaitTimeout := connTestReceiveTimeout + cm.cfg.RetryDuration pauseTargetOutbound.Store(false) assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) // Ensure persistent connections are also subject to the max per-host // connections by adding one and confirming it is NOT established. - _, err = cmgr.AddPersistent(nextSameHost()) + _, err = cm.AddPersistent(nextSameHost()) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) - assertConnManagerInternalState(t, cmgr) + assertConnManagerInternalState(t, cm) } From 817edd6eb9f77239ad3edb4d1b6bd56ad6f7820c Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 7 Jun 2026 22:50:10 -0500 Subject: [PATCH 28/51] connmgr: Use more modern t.Context in tests. Most of the tests in this package were written when the connection manager was based the older async model and before t.Context was available. As a result, the main method used to run the connection manager in all of the tests takes a context that is always set to a new background context. This updates the method to remove the parameter and instead use the test context via t.Context. --- internal/connmgr/connmanager_test.go | 42 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 0f46af1ec..98b2e9a59 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -51,10 +51,10 @@ func mustParseAddrPort(addr string) *net.TCPAddr { // // It also registers a test cleanup func that waits for shutdown and asserts the // internal state of the connection manager is empty as expected. -func runConnMgrAsync(t *testing.T, ctx context.Context, cm *ConnManager) (context.Context, context.CancelFunc, *sync.WaitGroup) { +func runConnMgrAsync(t *testing.T, cm *ConnManager) (context.Context, context.CancelFunc, *sync.WaitGroup) { t.Helper() - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(t.Context()) var wg sync.WaitGroup wg.Add(1) go func() { @@ -413,7 +413,7 @@ func TestConnectMode(t *testing.T) { connected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) addr := mustParseAddrPort("127.0.0.1:18555") go cm.Connect(ctx, addr) @@ -463,7 +463,7 @@ func TestDisconnect(t *testing.T) { disconnected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Attempt a connection to a localhost IP. notifyDialed.Store(true) @@ -621,7 +621,7 @@ func TestRemove(t *testing.T) { disconnected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Ensure removing an ID that doesn't exist returns the expected error. if err := cm.Remove(^uint64(0)); !errors.Is(err, ErrNotFound) { @@ -773,7 +773,7 @@ func TestTargetOutbound(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) // Ensure only the expected number of target outbound conns are established // and no more. @@ -800,7 +800,7 @@ func TestDoubleClose(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) // Wait for the connection to be established. conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) @@ -841,7 +841,7 @@ func TestRetryPersistent(t *testing.T) { disconnected <- conn }, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) addr := mustParseAddrPort("127.0.0.1:18555") connID, err := cm.AddPersistent(addr) @@ -887,7 +887,7 @@ func TestMaxPersistent(t *testing.T) { disconnected <- conn }, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) var numAddrs uint32 nextAddr := func() net.Addr { @@ -984,7 +984,7 @@ func TestMaxRetryDuration(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) connID, err := cm.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) if err != nil { @@ -1036,7 +1036,7 @@ func TestNetworkFailure(t *testing.T) { conn.RemoteAddr()) }, }) - _, shutdown, wg := runConnMgrAsync(t, context.Background(), cm) + _, shutdown, wg := runConnMgrAsync(t, cm) // Shutdown the connection manager after the max failed attempts is reached // and an additional retry duration has passed and then wait for the @@ -1088,7 +1088,7 @@ func TestMultipleFailedConns(t *testing.T) { Dial: errDialer, }) cm.maxRetryDuration = maxRetryDuration - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) // Establish several connection requests to localhost IPs. for i := range targetFailed { @@ -1142,7 +1142,7 @@ func TestShutdownFailedConns(t *testing.T) { Dial: waitDialer, }) cm.maxRetryDuration = retryTimeout - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) // Add a persistent connection. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1180,7 +1180,7 @@ func TestRemovePendingConnection(t *testing.T) { cm := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Establish a connection request to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1253,7 +1253,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { connected <- conn }, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) // Establish a persistent connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1311,7 +1311,7 @@ func TestDialTimeout(t *testing.T) { Dial: timeoutDialer, DialTimeout: dialTimeout, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Establish a connection to a localhost IP. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1344,7 +1344,7 @@ func TestConnectContext(t *testing.T) { cm := newTestConnManager(t, &Config{ Dial: indefiniteDialer, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Establish a connection request to a localhost IP with a separate context // that can be canceled. @@ -1400,7 +1400,7 @@ func TestListeners(t *testing.T) { }, Dial: mockDialer, }) - runConnMgrAsync(t, context.Background(), cm) + runConnMgrAsync(t, cm) // Fake a couple of mock connections to each of the listeners. go func() { @@ -1454,7 +1454,7 @@ func TestRejectDuplicateConns(t *testing.T) { disconnected <- conn }, }) - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Dial a manual connection and wait for it to become pending. addr := mustParseAddrPort("127.0.0.1:18555") @@ -1615,7 +1615,7 @@ func TestMaxNormalConns(t *testing.T) { }, }) cm.maxRetryDuration = cm.cfg.RetryDuration - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Wait for the expected number of target outbound conns to be established. outbounds := make([]*Conn, 0, targetOutbound) @@ -1767,7 +1767,7 @@ func TestMaxConnsPerHost(t *testing.T) { }, }) cm.maxRetryDuration = cm.cfg.RetryDuration - ctx, _, _ := runConnMgrAsync(t, context.Background(), cm) + ctx, _, _ := runConnMgrAsync(t, cm) // Wait for the maximum allowed non-whitelisted per-host automatic outbound // conns. From bd5bf252238eae5e8f60b1ab91c6bcb988b5f37e Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 25 May 2026 19:24:40 -0500 Subject: [PATCH 29/51] connmgr: Only close once in double close tests. --- internal/connmgr/connmanager_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 98b2e9a59..4eaa74d7a 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -811,7 +811,9 @@ func TestDoubleClose(t *testing.T) { origOnClose := conn.onClose conn.onClose = func() { numClosed++ - origOnClose() + if numClosed == 1 { + origOnClose() + } } // Close the connection multiple times and make sure it only happens once. From 1ce361d8a38a2455f16a64b4fd7077d77ce72d8e Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 25 May 2026 19:27:30 -0500 Subject: [PATCH 30/51] connmgr: Cleaner dial timeout detection test. Now that the Connect method is synchronous and returns an error, this modifies the test for detecting dial timeouts to use that error for more more accurate detection that the failure is actually the result of dial timeout as expected. --- internal/connmgr/connmanager_test.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 4eaa74d7a..c7954e25d 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -1298,12 +1298,10 @@ func TestDialTimeout(t *testing.T) { // Create a connection manager instance with a dialer that blocks for three // times the configured dial timeout before connecting. const dialTimeout = time.Millisecond * 20 - cancelled := make(chan struct{}) timeoutDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { select { case <-time.After(dialTimeout * 3): case <-ctx.Done(): - close(cancelled) return nil, ctx.Err() } @@ -1315,15 +1313,22 @@ func TestDialTimeout(t *testing.T) { }) ctx, _, _ := runConnMgrAsync(t, cm) - // Establish a connection to a localhost IP. - addr := mustParseAddrPort("127.0.0.1:18555") - go cm.Connect(ctx, addr) + connectErr := make(chan error, 1) + go func() { + addr := mustParseAddrPort("127.0.0.1:18555") + _, err := cm.Connect(ctx, addr) + connectErr <- err + }() assertConnManagerInternalState(t, cm) - // Wait to receive the signal that the dialer context was cancelled, which - // means the dial timeout was hit. + // Wait for the error from connect and ensure it is the expected deadline + // exceeded (aka dial timeout) error. select { - case <-cancelled: + case err := <-connectErr: + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("unexpected connect err: got %v, want %v", err, + context.Canceled) + } case <-time.After(dialTimeout * 10): t.Fatal("timeout waiting for dial cancellation") } From 0669e96fedfef69abf8cf259131efb5af46fd4b3 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 24 May 2026 17:53:48 -0500 Subject: [PATCH 31/51] connmgr: Use synctest for max retry duration test. This converts the max retry duration test over to use the synctest which makes it more robust and no longer reliant on real time. --- internal/connmgr/connmanager_test.go | 92 +++++++++++++++------------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index c7954e25d..9b8a58d24 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -15,6 +15,7 @@ import ( "sync" "sync/atomic" "testing" + "testing/synctest" "time" ) @@ -953,56 +954,61 @@ func TestMaxPersistent(t *testing.T) { assertConnManagerInternalState(t, cm) } -// TestMaxRetryDuration tests the maximum retry duration. -// -// We have a timed dialer which initially returns err but after RetryDuration -// hits maxRetryDuration returns a mock conn. +// TestMaxRetryDuration ensures the maximum retry duration is respected. func TestMaxRetryDuration(t *testing.T) { t.Parallel() - - // This test relies on the current value of the max retry duration defined - // in the tests, so assert it. - if defaultTestMaxRetryDuration != 2*time.Millisecond { - t.Fatalf("max retry duration of %v is not the required value for test", - defaultTestMaxRetryDuration) - } - - networkUp := make(chan struct{}) - timedDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - select { - case <-networkUp: - return mockDialer(ctx, network, addr) - default: - return nil, errors.New("network down") + synctest.Test(t, func(t *testing.T) { + networkUp := make(chan struct{}) + timedDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + select { + case <-networkUp: + return mockDialer(ctx, network, addr) + default: + return nil, errors.New("network down") + } } - } - connected := make(chan *Conn) - cm := newTestConnManager(t, &Config{ - RetryDuration: time.Millisecond, - TargetOutbound: 1, - Dial: timedDialer, - OnConnection: func(conn *Conn) { - connected <- conn - }, - }) - runConnMgrAsync(t, cm) + connected := make(chan *Conn) + cm := newTestConnManager(t, &Config{ + RetryDuration: time.Second, + Dial: timedDialer, + OnConnection: func(conn *Conn) { + connected <- conn + }, + }) + cm.maxRetryDuration = 2 * time.Second + runConnMgrAsync(t, cm) - connID, err := cm.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) - if err != nil { - t.Fatalf("failed to add persistent connection: %v", err) - } + connID, err := cm.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) + if err != nil { + t.Fatalf("failed to add persistent connection: %v", err) + } - // retry in 1ms - // retry in 2ms - max retry duration reached - // retry in 2ms - timedDialer returns [mockDialer] - const networkUpTimeout = 5 * time.Millisecond - time.AfterFunc(networkUpTimeout, func() { - close(networkUp) + // Approximate sequence of events. The exact number of retries will + // vary due to jitter. + // + // The test is stable regardless since it expects a connection within + // one max retry duration of the network being brought up and, as shown + // below, the retry duration without the max imposed would be far + // greater and not arrive in time. + // + // 0s: initial attempt (retry in ~1s) + // ~1s: retry 1 (retry in ~2s) - max retry duration reached + // ~3s: retry 2 (retry in ~2s, w/o max would be in ~3s => next at ~6s) + // ~5s: retry 3 (retry in ~2s, w/o max would be in ~4s => next at ~10s) + // ~7s: retry 4 (retry in ~2s, w/o max would be in ~5s => next at ~15s) + // ~9s: retry 5 (retry in ~2s, w/o max would be in ~6s => next at ~21s) + // ~11s: retry 6 (retry in ~2s, w/o max would be in ~7s => next at ~28s) + // ~12s: timedDialer returns [mockDialer] + // ~13s: retry 7 succeeds + networkUpTimeout := 6 * cm.maxRetryDuration + time.AfterFunc(networkUpTimeout, func() { + close(networkUp) + }) + timeout := networkUpTimeout + cm.maxRetryDuration + assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) + assertConnManagerInternalState(t, cm) }) - const timeout = connTestReceiveTimeout + networkUpTimeout - assertConnReceivedTimeout(t, connected, timeout, connID, ConnTypeManual) - assertConnManagerInternalState(t, cm) } // TestNetworkFailure tests that the connection manager handles a network From 5a8df14747110aa69c0445bb492bc235dab3c79c Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 24 May 2026 17:54:04 -0500 Subject: [PATCH 32/51] connmgr: Add deterministic CSPRNG for tests. There are various aspects of the connection manager that would benefit from making use of randomness. For example, adding random jitter to connection backoffs. However, by default, randomness often makes reproducing test failures extremely difficult since the next run will not not necessarily be following the same code branches due to different random values. In order to provide deterministic reproducibility, this adds a csprng interface that makes use of the dcrd crypto/rand module by default. All code that sources randomness moving forward is expected to use the interface. It also adds test infrastructure to automatically generate a new seed on each test iteration that is logged in the event of failure along with the ability to specify the seed via the -seed parameter. When running tests, the seed is then used to create a deterministic math/v2/chacha8 instance to implement the csprng interface. --- internal/connmgr/connmanager.go | 14 ++++- internal/connmgr/connmanager_test.go | 86 ++++++++++++++++++++++++++++ internal/connmgr/csprng.go | 62 ++++++++++++++++++++ 3 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 internal/connmgr/csprng.go diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index f11d28953..da6d0867a 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -318,6 +318,13 @@ type ConnManager struct { // creating time and treated as immutable after that. cfg Config + // csprng provides a cryptographically secure pseudorandom number generator. + // + // All code in the connection manager that relies on random values is + // expected to make use of this so that tests can replace the real + // implementation with a deterministic PRNG for reproducibility. + csprng csprng + // maxRetryDuration is the maximum duration a persistent connection retry // backoff is allowed to grow to. maxRetryDuration time.Duration @@ -340,7 +347,10 @@ type ConnManager struct { totalNormalConnsSem semaphore activeOutboundsSem semaphore - // The fields below this point are all protected by the connection mutex. + // ****************************************************************** + // The fields below this point are protected by the connection mutex. + // ****************************************************************** + connMtx sync.Mutex // persistent tracks all registered persistent connection entries. @@ -1457,6 +1467,7 @@ func New(cfg *Config) (*ConnManager, error) { if cfg.Dial == nil { return nil, MakeError(ErrDialNil, "dial cannot be nil") } + // Default to sane values. if cfg.RetryDuration <= 0 { cfg.RetryDuration = defaultRetryDuration @@ -1474,6 +1485,7 @@ func New(cfg *Config) (*ConnManager, error) { cm := ConnManager{ cfg: *cfg, // Copy so caller can't mutate quit: make(chan struct{}), + csprng: globalRand, maxRetryDuration: defaultMaxRetryDuration, runPersistentChan: make(chan *persistentEntry, MaxPersistent), totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 9b8a58d24..940cd9966 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -7,11 +7,18 @@ package connmgr import ( "context" + "encoding/binary" + "encoding/hex" "errors" + "flag" "fmt" + mrand "math/rand/v2" "net" "net/netip" + "os" "reflect" + "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -19,6 +26,82 @@ import ( "time" ) +// prngSeed is populated when the tests are initialized either by the -seed +// parameter if specified or a source of cryptographic randomness otherwise. +var prngSeed [32]byte + +// prngIteration is incremented each time a new test prng seed is requested so +// that each iteration when testing with -count > 1 gets a unique sequence of +// reproducible values. It can be overridden via the -seed parameter when the +// tests are initialized for easy reproducibility of test failures. +var prngIteration atomic.Uint32 + +func TestMain(m *testing.M) { + seedFlag := flag.String("seed", "", "use deterministic PRNG seed") + flag.Parse() + if *seedFlag != "" { + parts := strings.Split(*seedFlag, "/") + if len(parts) == 0 || len(parts) > 2 { + fmt.Fprintln(os.Stderr, "invalid -seed: format must be "+ + "<32 byte hex seed> or <32 byte hex seed>/") + os.Exit(1) + } + b, err := hex.DecodeString(parts[0]) + if err != nil { + fmt.Fprintln(os.Stderr, "invalid -seed hex:", err) + os.Exit(1) + } + if len(b) != 32 { + fmt.Fprintln(os.Stderr, "invalid -seed: must be 32 bytes") + os.Exit(1) + } + copy(prngSeed[:], b) + if len(parts) > 1 { + iteration, err := strconv.ParseUint(parts[1], 10, 32) + if err != nil { + fmt.Fprintln(os.Stderr, "invalid -seed iteration:", err) + os.Exit(1) + } + prngIteration.Store(uint32(iteration)) + } + } else { + globalRand.Read(prngSeed[:]) + } + os.Exit(m.Run()) +} + +// newTestPRNGSeed returns a seed to use for the deterministic test prng for the +// given iteration based on the global [prngSeed] variable which is populated in +// [TestMain]. +func newTestPRNGSeed(t testing.TB) [32]byte { + t.Helper() + + // Generate a new determinstic seed based on the test iteration count and + // global [prngSeed] variable which can be set with flags in [TestMain]. + iteration := prngIteration.Add(1) - 1 + t.Cleanup(func() { + t.Helper() + if t.Failed() { + runFlags := fmt.Sprintf("-run=%s", t.Name()) + if _, ok := t.(*testing.B); ok { + runFlags = fmt.Sprintf("-run=^$ -bench=%s", t.Name()) + } + t.Logf("Reproduce with: go test %s -seed=%x/%d -count=1", runFlags, + prngSeed, iteration) + } + }) + + // Increment the test seed by the iteration count so each test iteration has + // a unique seed derived from the overall test seed. + // + // This is hacky and not cryptographically sound, but it's is only used for + // tests, so it doesn't need to be. + seed := prngSeed + be := binary.BigEndian + be.PutUint32(seed[28:], be.Uint32(seed[28:])+iteration) + return seed +} + const ( // defaultTestMaxRetryDuration is the default max duration a connection // retry backoff is allowed to grow to when running tests. @@ -84,6 +167,9 @@ func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { t.Fatalf("New: unexpected error: %v", err) } cm.maxRetryDuration = defaultTestMaxRetryDuration + seed := newTestPRNGSeed(t) + src := mrand.NewChaCha8(seed) + cm.csprng = mrand.New(src) // nolint:gosec return cm } diff --git a/internal/connmgr/csprng.go b/internal/connmgr/csprng.go new file mode 100644 index 000000000..4f30b6d5b --- /dev/null +++ b/internal/connmgr/csprng.go @@ -0,0 +1,62 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package connmgr + +import ( + "sync" + + "github.com/decred/dcrd/crypto/rand" +) + +// csprng provides an interface for the CSPRNG methods the connection manager +// uses. This primarily exists so tests can replace the real implementation +// with a deterministic PRNG for reproducibility. +type csprng interface { + Uint64N(n uint64) uint64 +} + +// lockingPRNG wraps an instance of [rand.PRNG] with a mutex so it can be used +// concurrently. +type lockingPRNG struct { + prng *rand.PRNG + sync.Mutex +} + +// Uint64 returns a uniform random uint64. +func (p *lockingPRNG) Uint64() uint64 { + p.Lock() + defer p.Unlock() + + return p.prng.Uint64() +} + +// Uint64N returns a random uint64 in range [0,n) without modulo bias. +func (p *lockingPRNG) Uint64N(n uint64) uint64 { + p.Lock() + defer p.Unlock() + + return p.prng.Uint64N(n) +} + +// Read fills s with len(s) of cryptographically-secure random bytes. It never +// errors. +func (p *lockingPRNG) Read(s []byte) { + p.Lock() + defer p.Unlock() + + _, _ = p.prng.Read(s) +} + +// globalRand is set at init time so any failure to seed, which should never +// happen in practice, will cause a panic at startup instead of runtime. +var globalRand *lockingPRNG + +func init() { + p, err := rand.NewPRNG() + if err != nil { + panic(err) + } + globalRand = &lockingPRNG{prng: p} +} From 65532310e11d5b8d0c676a3d3d6b4330f94ab667 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 24 May 2026 17:54:08 -0500 Subject: [PATCH 33/51] connmgr: Implement exponential backoff with jitter. The current code uses a simple deterministic linear backoff with a maximum capacity. While it works well enough, it has a big downside in that all of the attempts will tend to coalesce over time. This is especially true during a network outage since all connections will drop at the same time. Also, exponential backoffs are preferred over linear to reduce the retry frequency more quickly and give the remote more time to come back online. This addresses both of those points by changing the linear backoff for persistent peers to use an exponential backoff with jitter instead. --- internal/connmgr/connmanager.go | 37 +++++++++++++++++++++++++--- internal/connmgr/connmanager_test.go | 10 ++++---- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index da6d0867a..065e89530 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -11,6 +11,7 @@ import ( "context" "errors" "fmt" + "math" "net" "net/netip" "strconv" @@ -329,6 +330,11 @@ type ConnManager struct { // backoff is allowed to grow to. maxRetryDuration time.Duration + // maxRetryScalingBits is the maximum number of bits the exponential backoff + // scaling factor can occupy such that multiplying by [Config.RetryDuration] + // is guaranteed not to overflow. + maxRetryScalingBits uint8 + // runPersistentChan is used to signal the persistent connections handler to // launch a goroutine that attempts to always maintain an established // connection with a given address. @@ -1198,6 +1204,30 @@ func (cm *ConnManager) FindPersistentAddrID(addr net.Addr) (uint64, bool) { return id, ok } +// backoffWithJitter returns an exponential backoff delay with additional jitter +// for the given number of retries. +func (cm *ConnManager) backoffWithJitter(retries uint32) time.Duration { + if retries == 0 { + return 0 + } + + // Calculate an expontential backoff capped to prevent overflow and clamped + // to the max retry duration. + shift := min(retries-1, uint32(cm.maxRetryScalingBits)) + factor := 1 << shift + + baseRetryDuration := cm.cfg.RetryDuration + backoff := min(baseRetryDuration*time.Duration(factor), cm.maxRetryDuration) + if backoff == 0 { + return 0 + } + + // Apply 50% jitter. + halfBackoff := backoff / 2 + jitter := time.Duration(cm.csprng.Uint64N(uint64(halfBackoff))) + return halfBackoff + jitter +} + // runPersistent attempts to maintain a persistent connection to the provided // address until the passed context is canceled. // @@ -1253,10 +1283,9 @@ func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr ne if retryCount < maxUint32 { retryCount++ } - retryWait := time.Duration(retryCount) * cm.cfg.RetryDuration - retryWait = min(retryWait, cm.maxRetryDuration) + retryWait := cm.backoffWithJitter(retryCount) log.Debugf("Retrying connection to %v in %v (retries %d)", addr, - retryWait, retryCount) + retryWait.Truncate(time.Microsecond), retryCount) retryAfter = time.After(retryWait) continue } @@ -1482,11 +1511,13 @@ func New(cfg *Config) (*ConnManager, error) { cfg.TargetOutbound = defaultTargetOutbound } cfg.TargetOutbound = min(cfg.TargetOutbound, cfg.MaxNormalConns) + retryDurationBits := uint8(math.Ceil(math.Log2(float64(cfg.RetryDuration)))) cm := ConnManager{ cfg: *cfg, // Copy so caller can't mutate quit: make(chan struct{}), csprng: globalRand, maxRetryDuration: defaultMaxRetryDuration, + maxRetryScalingBits: 63 - retryDurationBits, runPersistentChan: make(chan *persistentEntry, MaxPersistent), totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 940cd9966..caee107ee 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -1080,11 +1080,11 @@ func TestMaxRetryDuration(t *testing.T) { // // 0s: initial attempt (retry in ~1s) // ~1s: retry 1 (retry in ~2s) - max retry duration reached - // ~3s: retry 2 (retry in ~2s, w/o max would be in ~3s => next at ~6s) - // ~5s: retry 3 (retry in ~2s, w/o max would be in ~4s => next at ~10s) - // ~7s: retry 4 (retry in ~2s, w/o max would be in ~5s => next at ~15s) - // ~9s: retry 5 (retry in ~2s, w/o max would be in ~6s => next at ~21s) - // ~11s: retry 6 (retry in ~2s, w/o max would be in ~7s => next at ~28s) + // ~3s: retry 2 (retry in ~2s, w/o max would be in ~4s => next at ~7s) + // ~5s: retry 3 (retry in ~2s, w/o max would be in ~8s => next at ~15s) + // ~7s: retry 4 (retry in ~2s, w/o max would be in ~16s => next at ~33s) + // ~9s: retry 5 (retry in ~2s, w/o max would be in ~32s => next at ~65s) + // ~11s: retry 6 (retry in ~2s, w/o max would be in ~64s => next at ~129s) // ~12s: timedDialer returns [mockDialer] // ~13s: retry 7 succeeds networkUpTimeout := 6 * cm.maxRetryDuration From 6be73edaa13ab121614d4d4c83945f3fa0aece64 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 24 May 2026 21:11:19 -0500 Subject: [PATCH 34/51] connmgr: Update README.md for jitter. --- internal/connmgr/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 30ebaef01..a7e171979 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -28,7 +28,7 @@ The following is a brief overview of the key features: address source (`GetNewAddress`) - Persistent connections - Maintains up to `MaxPersistent` addresses that are automatically retried - with exponential backoff on disconnect + with exponential backoff and jitter on disconnect - Manual connections - Supports manual connection establishment via `Connect` - Connection limits From ce4dd53882e834272090087682c2bb775dc5a3bb Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:11 -0500 Subject: [PATCH 35/51] connmgr: Wait for all goroutines to finish. This modifies the persistent connection and target outbound handler shutdown logic to wait for any goroutines they have launched to finish before returning. This ensures there are no dangling goroutines from them once Run returns. --- internal/connmgr/connmanager.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 065e89530..3e20be429 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -1320,6 +1320,10 @@ func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr ne // persistentConnsHandler handles launching individual goroutines for persistent // connections. func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { + // Ensure all persistent handlers are done before returning. + var wg sync.WaitGroup + defer wg.Wait() + for { select { case entry := <-cm.runPersistentChan: @@ -1327,7 +1331,11 @@ func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { cm.connMtx.Lock() entry.cancel = cancel cm.connMtx.Unlock() - go cm.runPersistent(pCtx, entry.id, entry.addr) + wg.Add(1) + go func() { + cm.runPersistent(pCtx, entry.id, entry.addr) + wg.Done() + }() case <-ctx.Done(): return @@ -1344,6 +1352,10 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { log.Trace("Starting target outbound handler") defer log.Trace("Target outbound handler done") + // Ensure potential pending dial cleanup is done before returning. + var wg sync.WaitGroup + defer wg.Wait() + // failedAttempts tracks the total number of failed outbound connection // attempts since the last successful connection. It is primarily used to // detect network outages in order to impose a retry timeout on achieving @@ -1396,7 +1408,9 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { continue } + wg.Add(1) go func(addr net.Addr) { + defer wg.Done() onClose := func() { cm.totalNormalConnsSem.Release() cm.activeOutboundsSem.Release() From 4040caf16e0d2ff201da891324343daf12a7d33d Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:13 -0500 Subject: [PATCH 36/51] connmgr: Use addr generator in tests. As the connection manager adds more functionality, the hard coded localhost addresses throughout the tests will likely lead to issues due to special handling for loopback and whitelisted addresses. Manually having to update them as things break would not be ideal, so this introduces a test address generator and updates all of the tests to make use of it. The generator uses normal, routable, IPv4 addresses and makes it easy to use new address ranges to deal with any changes that would otherwise break tests. --- internal/connmgr/connmanager_test.go | 199 +++++++++++++++++---------- 1 file changed, 129 insertions(+), 70 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index caee107ee..ee18bfea5 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -114,6 +114,9 @@ const ( // connTestNonReceiveTimeout is the default timeout used throughout the // tests when expecting that a connection will NOT be received. connTestNonReceiveTimeout = 20 * time.Millisecond + + // defaultTestP2PPort is the default p2p port to use throughout the test. + defaultTestP2PPort = 18555 ) // mustParseAddrPort parses the provided address into a [*net.TCPAddr] and will @@ -128,6 +131,82 @@ func mustParseAddrPort(addr string) *net.TCPAddr { } } +// addrGenerator houses state for an address generator used to simplify tests. +type addrGenerator struct { + mtx sync.Mutex + outboundGroupPrefixBits uint + addr netip.Addr + port uint16 +} + +// newAddrGenerator returns a new address generator configured to start at the +// given base ip:port. +func newAddrGenerator(baseAddrPort string) *addrGenerator { + addrPort := netip.MustParseAddrPort(baseAddrPort) + return &addrGenerator{ + outboundGroupPrefixBits: 16, + addr: addrPort.Addr(), + port: addrPort.Port(), + } +} + +// Next advances the generator to the next IP and returns the result. It skips +// all addresses of the form "x.x.x.0". +// +// This provides convenient access to a new unique address with every +// invocation. +func (g *addrGenerator) Next() net.Addr { + g.mtx.Lock() + defer g.mtx.Unlock() + + // Skip "x.x.x.0". + g.addr = g.addr.Next() + if g.addr.As4()[3] == 0 { + g.addr = g.addr.Next() + } + + port := strconv.Itoa(int(g.port)) + return mustParseAddrPort(net.JoinHostPort(g.addr.String(), port)) +} + +// NextPort advances the generator to the next port of the current IP and +// returns the result. It skips ports 0-1024. +// +// This provides convenient access to a new endpoint with the same host address +// with every invocation. +func (g *addrGenerator) NextPort() net.Addr { + g.mtx.Lock() + defer g.mtx.Unlock() + + g.port++ + if g.port < 1025 { + g.port = 1025 + } + + port := strconv.Itoa(int(g.port)) + return mustParseAddrPort(net.JoinHostPort(g.addr.String(), port)) +} + +// defaultAddrGenerator returns an address generator configured with a default +// starting base address and port useful throughout the tests. The base address +// is a normal routable IPv4 address. +func defaultAddrGenerator() *addrGenerator { + return newAddrGenerator(fmt.Sprintf("12.0.0.0:%d", defaultTestP2PPort)) +} + +// defaultTestAddr returns a default address to use throughout the tests. It is +// a convenient way to get the first address generated by the default address +// generator. +func defaultTestAddr() net.Addr { + return defaultAddrGenerator().Next() +} + +// defaultMockListener returns a default mock listener to use throughout the +// tests. +func defaultMockListener() *mockListener { + return newMockListener(fmt.Sprintf("127.0.0.1:%d", defaultTestP2PPort)) +} + // runConnMgrAsync invokes [ConnManager.Run] on the passed connection manager in // a separate goroutine and returns a cancelable context and wait group the // caller can use to shutdown the connection manager and wait for clean @@ -502,8 +581,7 @@ func TestConnectMode(t *testing.T) { }) ctx, _, _ := runConnMgrAsync(t, cm) - addr := mustParseAddrPort("127.0.0.1:18555") - go cm.Connect(ctx, addr) + go cm.Connect(ctx, defaultTestAddr()) // Ensure that only a single connection is received. assertConnReceived(t, connected, 0, ConnTypeManual) @@ -552,11 +630,11 @@ func TestDisconnect(t *testing.T) { }) ctx, _, _ := runConnMgrAsync(t, cm) - // Attempt a connection to a localhost IP. + // Attempt a connection. notifyDialed.Store(true) waitForPending.Store(true) notifyCanceled.Store(true) - addr := mustParseAddrPort("127.0.0.1:18555") + addr := defaultTestAddr() go cm.Connect(ctx, addr) // Wait for the connection manager to attempt to dial and ensure the @@ -715,11 +793,11 @@ func TestRemove(t *testing.T) { t.Fatalf("mismatched remove error: got %v, want %v", err, ErrNotFound) } - // Attempt a connection to a localhost IP. + // Attempt a connection. notifyDialed.Store(true) waitForPending.Store(true) notifyCanceled.Store(true) - addr := mustParseAddrPort("127.0.0.1:18555") + addr := defaultTestAddr() go cm.Connect(ctx, addr) // Wait for the connection manager to attempt to dial and ensure the @@ -846,15 +924,14 @@ func TestRemove(t *testing.T) { func TestTargetOutbound(t *testing.T) { t.Parallel() + addrGen := defaultAddrGenerator() const targetOutbound = 10 - var nextAddr atomic.Uint32 connected := make(chan *Conn) cm := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { - addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) - return mustParseAddrPort(addrStr), nil + return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -876,12 +953,13 @@ func TestTargetOutbound(t *testing.T) { func TestDoubleClose(t *testing.T) { t.Parallel() + addrGen := defaultAddrGenerator() connected := make(chan *Conn) cm := newTestConnManager(t, &Config{ TargetOutbound: 1, Dial: mockDialer, GetNewAddress: func() (net.Addr, error) { - return mustParseAddrPort("127.0.0.1:18555"), nil + return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -932,7 +1010,7 @@ func TestRetryPersistent(t *testing.T) { }) runConnMgrAsync(t, cm) - addr := mustParseAddrPort("127.0.0.1:18555") + addr := defaultTestAddr() connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) @@ -978,18 +1056,12 @@ func TestMaxPersistent(t *testing.T) { }) runConnMgrAsync(t, cm) - var numAddrs uint32 - nextAddr := func() net.Addr { - numAddrs++ - addrStr := fmt.Sprintf("127.0.0.%d:18555", numAddrs) - return mustParseAddrPort(addrStr) - } - // Add the maximum allowed number of persistent conns. + addrGen := defaultAddrGenerator() connIDs := make([]uint64, 0, MaxPersistent) addrs := make([]net.Addr, 0, MaxPersistent) for range MaxPersistent { - addr := nextAddr() + addr := addrGen.Next() connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection %v: %v", addr, err) @@ -1004,7 +1076,7 @@ func TestMaxPersistent(t *testing.T) { // Attempting to add more than the max allowed number of persistent conns // should be rejected. - _, err := cm.AddPersistent(nextAddr()) + _, err := cm.AddPersistent(addrGen.Next()) if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject > max persistent, err: %v", err) } @@ -1016,7 +1088,7 @@ func TestMaxPersistent(t *testing.T) { if err := cm.Disconnect(connID); err != nil { t.Fatalf("failed to disconnect persistent conn %v: %v", addr, err) } - _, err = cm.AddPersistent(nextAddr()) + _, err = cm.AddPersistent(addrGen.Next()) if !errors.Is(err, ErrMaxPersistent) { t.Fatalf("did not reject max persistent after dc, err: %v", err) } @@ -1032,7 +1104,7 @@ func TestMaxPersistent(t *testing.T) { assertConnManagerInternalState(t, cm) // A new persistent conn should now be allowed. - addr = nextAddr() + addr = addrGen.Next() _, err = cm.AddPersistent(addr) if err != nil { t.Fatalf("failed to add persistent connection %v: %v", addr, err) @@ -1065,7 +1137,7 @@ func TestMaxRetryDuration(t *testing.T) { cm.maxRetryDuration = 2 * time.Second runConnMgrAsync(t, cm) - connID, err := cm.AddPersistent(mustParseAddrPort("127.0.0.1:18555")) + connID, err := cm.AddPersistent(defaultTestAddr()) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } @@ -1116,14 +1188,13 @@ func TestNetworkFailure(t *testing.T) { } return nil, errors.New("network down") } - var nextAddr atomic.Uint32 + addrGen := defaultAddrGenerator() cm := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, GetNewAddress: func() (net.Addr, error) { - addrStr := fmt.Sprintf("127.0.0.%d:18555", nextAddr.Add(1)) - return mustParseAddrPort(addrStr), nil + return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { t.Fatalf("network failure: got unexpected connection - %v", @@ -1184,10 +1255,10 @@ func TestMultipleFailedConns(t *testing.T) { cm.maxRetryDuration = maxRetryDuration runConnMgrAsync(t, cm) - // Establish several connection requests to localhost IPs. - for i := range targetFailed { - addr := mustParseAddrPort(fmt.Sprintf("127.0.0.%d:18555", i+1)) - _, err := cm.AddPersistent(addr) + // Establish several persistent connections. + addrGen := defaultAddrGenerator() + for range targetFailed { + _, err := cm.AddPersistent(addrGen.Next()) if err != nil { t.Fatalf("unexpected add err: %v", err) } @@ -1239,8 +1310,7 @@ func TestShutdownFailedConns(t *testing.T) { runConnMgrAsync(t, cm) // Add a persistent connection. - addr := mustParseAddrPort("127.0.0.1:18555") - _, err := cm.AddPersistent(addr) + _, err := cm.AddPersistent(defaultTestAddr()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -1276,8 +1346,8 @@ func TestRemovePendingConnection(t *testing.T) { }) ctx, _, _ := runConnMgrAsync(t, cm) - // Establish a connection request to a localhost IP. - addr := mustParseAddrPort("127.0.0.1:18555") + // Establish a connection request. + addr := defaultTestAddr() go cm.Connect(ctx, addr) assertConnManagerInternalState(t, cm) @@ -1349,8 +1419,8 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { }) runConnMgrAsync(t, cm) - // Establish a persistent connection to a localhost IP. - addr := mustParseAddrPort("127.0.0.1:18555") + // Establish a persistent connection. + addr := defaultTestAddr() connID, err := cm.AddPersistent(addr) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -1405,10 +1475,10 @@ func TestDialTimeout(t *testing.T) { }) ctx, _, _ := runConnMgrAsync(t, cm) + // Establish a connection. connectErr := make(chan error, 1) go func() { - addr := mustParseAddrPort("127.0.0.1:18555") - _, err := cm.Connect(ctx, addr) + _, err := cm.Connect(ctx, defaultTestAddr()) connectErr <- err }() assertConnManagerInternalState(t, cm) @@ -1445,9 +1515,9 @@ func TestConnectContext(t *testing.T) { }) ctx, _, _ := runConnMgrAsync(t, cm) - // Establish a connection request to a localhost IP with a separate context - // that can be canceled. - addr := mustParseAddrPort("127.0.0.1:18555") + // Establish a connection request with a separate context that can be + // canceled. + addr := defaultTestAddr() connectCtx, cancelConnect := context.WithCancel(ctx) connectErr := make(chan error, 1) go func() { @@ -1530,7 +1600,7 @@ func TestRejectDuplicateConns(t *testing.T) { var closeDialedOnce sync.Once inboundConns := make(chan *Conn) - listener := newMockListener("127.0.0.1:18109") + listener := defaultMockListener() connected := make(chan *Conn) disconnected := make(chan *Conn) dialed := make(chan struct{}) @@ -1556,7 +1626,7 @@ func TestRejectDuplicateConns(t *testing.T) { ctx, _, _ := runConnMgrAsync(t, cm) // Dial a manual connection and wait for it to become pending. - addr := mustParseAddrPort("127.0.0.1:18555") + addr := defaultTestAddr() go cm.Connect(ctx, addr) select { case <-dialed: @@ -1664,14 +1734,6 @@ func TestRejectDuplicateConns(t *testing.T) { func TestMaxNormalConns(t *testing.T) { t.Parallel() - // nextAddr is a convenience func to return a new unique address with every - // invocation. - var numAddrs atomic.Uint32 - nextAddr := func() net.Addr { - addrStr := fmt.Sprintf("10.0.0.%d:18555", numAddrs.Add(1)) - return mustParseAddrPort(addrStr) - } - // Constants for the number of various normal connection types to test // overall max normal connection limits. const ( @@ -1683,10 +1745,11 @@ func TestMaxNormalConns(t *testing.T) { connected := make(chan *Conn) disconnected := make(chan *Conn) inboundConns := make(chan *Conn) - listener := newMockListener("127.0.0.1:9108") + listener := defaultMockListener() var pauseTargetOutbound atomic.Bool var totalPausedAddrs atomic.Uint32 hitMaxFailedAttempts := make(chan struct{}) + addrGen := defaultAddrGenerator() cm := newTestConnManager(t, &Config{ Listeners: []net.Listener{listener}, MaxNormalConns: maxNormalConns, @@ -1704,7 +1767,7 @@ func TestMaxNormalConns(t *testing.T) { } return nil, errors.New("network down") } - return nextAddr(), nil + return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1728,7 +1791,7 @@ func TestMaxNormalConns(t *testing.T) { // be established. go func() { for range targetInbound { - listener.Connect(nextAddr()) + listener.Connect(addrGen.Next()) } }() inbounds := make([]*Conn, 0, targetInbound) @@ -1742,7 +1805,7 @@ func TestMaxNormalConns(t *testing.T) { // established. go func() { for range targetManual { - go cm.Connect(ctx, nextAddr()) + go cm.Connect(ctx, addrGen.Next()) } }() manualConns := make([]*Conn, 0, targetManual+1) @@ -1754,7 +1817,7 @@ func TestMaxNormalConns(t *testing.T) { // Ensure manual connections that would exceed the max allowed normal // connections are rejected. - _, err := cm.Connect(ctx, nextAddr()) + _, err := cm.Connect(ctx, addrGen.Next()) if !errors.Is(err, ErrMaxNormalConns) { t.Fatalf("did not reject manual connection at max allowed, err: %v", err) } @@ -1762,7 +1825,7 @@ func TestMaxNormalConns(t *testing.T) { // Ensure inbound connections that would exceed the max allowed normal // connections are rejected. - go listener.Connect(nextAddr()) + go listener.Connect(addrGen.Next()) assertNoConnReceived(t, inboundConns) assertConnManagerInternalState(t, cm) @@ -1784,7 +1847,7 @@ func TestMaxNormalConns(t *testing.T) { // Establish another manual connection to take the place of the target // outbound connection that was just closed and wait for it to be // established. - go cm.Connect(ctx, nextAddr()) + go cm.Connect(ctx, addrGen.Next()) assertConnReceived(t, connected, 0, ConnTypeManual) assertConnManagerInternalState(t, cm) @@ -1798,7 +1861,7 @@ func TestMaxNormalConns(t *testing.T) { // Ensure persistent connections are not subject to the max total normal // connections by adding one and waiting for it to be established. - connID, err := cm.AddPersistent(nextAddr()) + connID, err := cm.AddPersistent(addrGen.Next()) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } @@ -1815,25 +1878,21 @@ func TestMaxConnsPerHost(t *testing.T) { // nextSameHost is a convenience func to return a new address to the same IP // with a different port on every invocation. - var nextPort atomic.Uint32 - nextSameHost := func() net.Addr { - addrStr := fmt.Sprintf("10.10.0.1:%d", nextPort.Add(1)+1024) - return mustParseAddrPort(addrStr) - } + addrGen := defaultAddrGenerator() + addrGen.Next() + nextSameHost := addrGen.NextPort // nextSameHostWhitelisted is a convenience func to return a new address to // the same whitelisted IP with a different port on every invocation. - allowedIP := netip.MustParseAddr("10.20.0.1") - nextSameWhitelistedHost := func() net.Addr { - addrStr := fmt.Sprintf("%s:%d", allowedIP, nextPort.Add(1)+1024) - return mustParseAddrPort(addrStr) - } + allowedIP := netip.MustParseAddr("12.20.0.1") + addrGenWhitelisted := newAddrGenerator(allowedIP.String() + ":1025") + nextSameWhitelistedHost := addrGenWhitelisted.NextPort const maxConnsPerHost = 3 connected := make(chan *Conn, 1) disconnected := make(chan *Conn, 1) inboundConns := make(chan *Conn) - listener := newMockListener("127.0.0.1:9108") + listener := defaultMockListener() var pauseTargetOutbound atomic.Bool var totalPausedAddrs atomic.Uint32 hitMaxFailedAttempts := make(chan struct{}) From 7faf0ea8976e5d059864926310e0033a27892e26 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:13 -0500 Subject: [PATCH 37/51] connmgr: Use concrete addrmgr types in test. This updates the primary parsing method in the connection manager tests to return a concrete addrmgr address instead of a stdlib net.TCPAddr. The goal is to eventually use the concrete address type almost everywhere to avoid a lot of the less than ideal address reparsing and repeated host/port splitting and joining. --- internal/connmgr/connmanager_test.go | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index ee18bfea5..ab909405c 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -24,6 +24,8 @@ import ( "testing" "testing/synctest" "time" + + "github.com/decred/dcrd/addrmgr/v4" ) // prngSeed is populated when the tests are initialized either by the -seed @@ -119,16 +121,13 @@ const ( defaultTestP2PPort = 18555 ) -// mustParseAddrPort parses the provided address into a [*net.TCPAddr] and will -// panic if there is an error. It will only (and must only) be called with -// hard-coded, and therefore known good, addresses. -func mustParseAddrPort(addr string) *net.TCPAddr { +// mustParseAddrPort parses the provided address into a [*addrmgr.NetAddress] +// and will panic if there is an error. It will only (and must only) be called +// with hard-coded, and therefore known good, addresses. +func mustParseAddrPort(addr string) *addrmgr.NetAddress { addrPort := netip.MustParseAddrPort(addr) - return &net.TCPAddr{ - IP: addrPort.Addr().AsSlice(), - Port: int(addrPort.Port()), - Zone: addrPort.Addr().Zone(), - } + return addrmgr.NewNetAddressFromIPPort(addrPort.Addr().AsSlice(), + addrPort.Port(), 0) } // addrGenerator houses state for an address generator used to simplify tests. @@ -155,7 +154,7 @@ func newAddrGenerator(baseAddrPort string) *addrGenerator { // // This provides convenient access to a new unique address with every // invocation. -func (g *addrGenerator) Next() net.Addr { +func (g *addrGenerator) Next() *addrmgr.NetAddress { g.mtx.Lock() defer g.mtx.Unlock() @@ -165,8 +164,7 @@ func (g *addrGenerator) Next() net.Addr { g.addr = g.addr.Next() } - port := strconv.Itoa(int(g.port)) - return mustParseAddrPort(net.JoinHostPort(g.addr.String(), port)) + return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) } // NextPort advances the generator to the next port of the current IP and @@ -174,7 +172,7 @@ func (g *addrGenerator) Next() net.Addr { // // This provides convenient access to a new endpoint with the same host address // with every invocation. -func (g *addrGenerator) NextPort() net.Addr { +func (g *addrGenerator) NextPort() *addrmgr.NetAddress { g.mtx.Lock() defer g.mtx.Unlock() @@ -183,8 +181,7 @@ func (g *addrGenerator) NextPort() net.Addr { g.port = 1025 } - port := strconv.Itoa(int(g.port)) - return mustParseAddrPort(net.JoinHostPort(g.addr.String(), port)) + return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) } // defaultAddrGenerator returns an address generator configured with a default From e8296746de491c4d59594ca86841cdad7dd37ec7 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:14 -0500 Subject: [PATCH 38/51] connmgr: Use concrete addr type more often. This modifies the address handling to parse the stdlib net.Addr to a concrete addrmgr address earlier in the connect and persistent paths rather than doing it in the dial method and switch the callback to get new addresses to return the concrete type. The server is updated to simply the return the chosen address which no longer needs to be converted. Note that this is theoretically a semantics change because the code in server previously potentially resolved the address via DNS and that no longer is the case. In practice, nothing is really changing in terms of resolution though because the address manager only ever works with resolved addresses since it needs to gossip addresses via the wire protocol which only deals with resolved addresses. --- internal/connmgr/connmanager.go | 81 +++++++++++++++------------- internal/connmgr/connmanager_test.go | 15 +++--- server.go | 6 +-- 3 files changed, 53 insertions(+), 49 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 3e20be429..6c0a042a6 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -199,10 +199,9 @@ func (c *Conn) Type() ConnectionType { // pendingConnInfo houses information about a pending connection attempt. type pendingConnInfo struct { - id uint64 - addr *addrmgr.NetAddress - hostKey string - cancel context.CancelFunc + id uint64 + addr *addrmgr.NetAddress + cancel context.CancelFunc } // persistentEntry houses information about a persistent connection that has @@ -293,7 +292,7 @@ type Config struct { // GetNewAddress is a way to get an address to make a network connection // to. If nil, no new connections will be made automatically. - GetNewAddress func() (net.Addr, error) + GetNewAddress func() (*addrmgr.NetAddress, error) // Dial connects to the address on the named network. Dial func(ctx context.Context, network, addr string) (net.Conn, error) @@ -419,6 +418,12 @@ func (cm *ConnManager) checkShutdown() error { // stdlibNetAddrToAddrMgrNetAddr converts the provided standard lib [net.Addr] // to a concrete address manager address. func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { + // Fast path for most addresses. + if na, ok := addr.(*addrmgr.NetAddress); ok { + return na, nil + } + + // Fall back to slower string parsing. host, portStr, err := net.SplitHostPort(addr.String()) if err != nil { str := fmt.Sprintf("unable to split address %q", addr) @@ -447,17 +452,8 @@ func stdlibNetAddrToAddrMgrNetAddr(addr net.Addr) (*addrmgr.NetAddress, error) { // addrHostKey returns the host portion of the passed address as a string // suitable for use as a map key. -func addrHostKey(addr net.Addr) string { - if na, ok := addr.(*addrmgr.NetAddress); ok { - return net.IP(na.IP).String() - } - - addrStr := addr.String() - host, _, err := net.SplitHostPort(addrStr) - if err == nil { - return host - } - return addrStr +func addrHostKey(addr *addrmgr.NetAddress) string { + return net.IP(addr.IP).String() } // decrementPerHostCount decrements the reference count for the provided host @@ -479,7 +475,7 @@ func (cm *ConnManager) addPendingInfo(info *pendingConnInfo) { cm.pending[info.id] = info if _, ok := cm.persistent[info.id]; !ok { cm.connIDByAddr[info.addr.String()] = info.id - cm.perHostCounts[info.hostKey]++ + cm.perHostCounts[addrHostKey(info.addr)]++ } } @@ -490,7 +486,7 @@ func (cm *ConnManager) removePendingInfo(info *pendingConnInfo) { delete(cm.pending, info.id) if _, ok := cm.persistent[info.id]; !ok { delete(cm.connIDByAddr, info.addr.String()) - cm.decrementPerHostCount(info.hostKey) + cm.decrementPerHostCount(addrHostKey(info.addr)) } } @@ -620,7 +616,7 @@ func (cm *ConnManager) rejectDuplicateAddr(addr *addrmgr.NetAddress) error { // not exempt. // // This function MUST be called with the connection mutex held (reads). -func (cm *ConnManager) rejectMaxConnsPerHost(addr *addrmgr.NetAddress, hostKey string, isWhitelisted bool) error { +func (cm *ConnManager) rejectMaxConnsPerHost(addr *addrmgr.NetAddress, isWhitelisted bool) error { // Whitelisted and loopback addresses are exempt. isLoopback := net.IP(addr.IP).IsLoopback() if isWhitelisted || isLoopback { @@ -628,7 +624,7 @@ func (cm *ConnManager) rejectMaxConnsPerHost(addr *addrmgr.NetAddress, hostKey s } maxAllowed := cm.cfg.MaxConnsPerHost - if numConns := cm.perHostCounts[hostKey]; numConns+1 > maxAllowed { + if numConns := cm.perHostCounts[addrHostKey(addr)]; numConns+1 > maxAllowed { str := fmt.Sprintf("a maximum of %d %s per host is allowed", maxAllowed, pickNoun(maxAllowed, "connection", "connections")) return MakeError(ErrMaxConnsPerHost, str) @@ -679,7 +675,7 @@ func (cm *ConnManager) rejectMaxConnsPerHost(addr *addrmgr.NetAddress, hostKey s // before the timeout configured for the connection manager // // This function is safe for concurrent access. -func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType ConnectionType, onClose func(), persistentConnID *uint64) (*Conn, error) { +func (cm *ConnManager) dial(ctx context.Context, addr *addrmgr.NetAddress, connType ConnectionType, onClose func(), persistentConnID *uint64) (*Conn, error) { var skipOnClose bool defer func() { if !skipOnClose && onClose != nil { @@ -696,12 +692,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect return nil, ctx.Err() } - rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) - if err != nil { - return nil, err - } - rAddrHostKey := addrHostKey(rAddr) - isWhitelisted := cm.IsWhitelisted(rAddr) + isWhitelisted := cm.IsWhitelisted(addr) // Reject attempts to dial addresses that are already connected (or in the // process of it). Additionally, reject attempts to dial existing @@ -715,17 +706,17 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect rejectFn = cm.rejectConnectedAddr } cm.connMtx.Lock() - if err := rejectFn(rAddr); err != nil { + if err := rejectFn(addr); err != nil { cm.connMtx.Unlock() log.Debugf("Rejected connection: %v", err) return nil, err } // Limit the max number of connections per host. - err = cm.rejectMaxConnsPerHost(rAddr, rAddrHostKey, isWhitelisted) + err := cm.rejectMaxConnsPerHost(addr, isWhitelisted) if err != nil { cm.connMtx.Unlock() - log.Debugf("Rejected connection to %v: %v", rAddr, err) + log.Debugf("Rejected connection to %v: %v", addr, err) return nil, err } @@ -747,7 +738,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect } else { connID = cm.nextConnID.Add(1) } - info := &pendingConnInfo{connID, rAddr, rAddrHostKey, cancel} + info := &pendingConnInfo{connID, addr, cancel} cm.addPendingInfo(info) cm.connMtx.Unlock() defer func() { @@ -819,7 +810,7 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // Create a new connection instance with the connection ID and type and add // an entry to the map that tracks all active connections. - conn = newConn(cm, netConn, connID, connType, rAddr, dialOnClose) + conn = newConn(cm, netConn, connID, connType, addr, dialOnClose) cm.addActiveConn(conn) cm.connMtx.Unlock() @@ -864,6 +855,11 @@ func (cm *ConnManager) dial(ctx context.Context, addr net.Addr, connType Connect // // This function is safe for concurrent access. func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error) { + rAddr, err := stdlibNetAddrToAddrMgrNetAddr(addr) + if err != nil { + return nil, err + } + acquired, err := cm.totalNormalConnsSem.TryAcquire(ctx) if err != nil { if sErr := cm.checkShutdown(); sErr != nil { @@ -878,7 +874,7 @@ func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error return nil, MakeError(ErrMaxNormalConns, str) } onClose := cm.totalNormalConnsSem.Release - conn, err := cm.dial(ctx, addr, ConnTypeManual, onClose, nil) + conn, err := cm.dial(ctx, rAddr, ConnTypeManual, onClose, nil) if err != nil { return nil, err } @@ -1036,7 +1032,6 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) netConn.Close() continue } - rAddrHostKey := addrHostKey(rAddr) isWhitelisted := cm.IsWhitelisted(rAddr) // Reject connections with the same host:port as any existing pending, @@ -1056,7 +1051,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) } // Limit the max number of connections per host. - err = cm.rejectMaxConnsPerHost(rAddr, rAddrHostKey, isWhitelisted) + err = cm.rejectMaxConnsPerHost(rAddr, isWhitelisted) if err != nil { cm.connMtx.Unlock() log.Debugf("Dropped connection from %v: %v", rAddr, err) @@ -1235,7 +1230,7 @@ func (cm *ConnManager) backoffWithJitter(retries uint32) time.Duration { // increasing backoff, up to a maximum for repeated failed attempts. // // This MUST be run as a goroutine. -func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr net.Addr) { +func (cm *ConnManager) runPersistent(ctx context.Context, connID uint64, addr *addrmgr.NetAddress) { // Ensure the connection is closed when the goroutine exits. var conn *Conn defer func() { @@ -1343,6 +1338,16 @@ func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { } } +// pickOutboundAddr returns an address suitable for establishing a new outbound +// connection. +// +// It simply delegates to [Config.GetNewAddress] for now. +// +// This function is safe for concurrent access. +func (cm *ConnManager) pickOutboundAddr() (*addrmgr.NetAddress, error) { + return cm.cfg.GetNewAddress() +} + // targetOutboundHandler attempts to automatically maintain the target number of // outbound connections configured via [Config.TargetOutbound] when initially // creating the connection manager. @@ -1399,7 +1404,7 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { return } - addr, err := cm.cfg.GetNewAddress() + addr, err := cm.pickOutboundAddr() if err != nil { failedAttempts.Add(1) log.Debugf("Failed to get address for outbound connection: %v", err) @@ -1409,7 +1414,7 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { } wg.Add(1) - go func(addr net.Addr) { + go func(addr *addrmgr.NetAddress) { defer wg.Done() onClose := func() { cm.totalNormalConnsSem.Release() diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index ab909405c..0f5d88d55 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -927,7 +927,7 @@ func TestTargetOutbound(t *testing.T) { cm := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, - GetNewAddress: func() (net.Addr, error) { + GetNewAddress: func() (*addrmgr.NetAddress, error) { return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { @@ -955,7 +955,7 @@ func TestDoubleClose(t *testing.T) { cm := newTestConnManager(t, &Config{ TargetOutbound: 1, Dial: mockDialer, - GetNewAddress: func() (net.Addr, error) { + GetNewAddress: func() (*addrmgr.NetAddress, error) { return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { @@ -995,9 +995,8 @@ func TestRetryPersistent(t *testing.T) { connected := make(chan *Conn) disconnected := make(chan *Conn) cm := newTestConnManager(t, &Config{ - RetryDuration: time.Millisecond, - TargetOutbound: 1, - Dial: mockDialer, + RetryDuration: time.Millisecond, + Dial: mockDialer, OnConnection: func(conn *Conn) { connected <- conn }, @@ -1190,7 +1189,7 @@ func TestNetworkFailure(t *testing.T) { TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, - GetNewAddress: func() (net.Addr, error) { + GetNewAddress: func() (*addrmgr.NetAddress, error) { return addrGen.Next(), nil }, OnConnection: func(conn *Conn) { @@ -1756,7 +1755,7 @@ func TestMaxNormalConns(t *testing.T) { OnAccept: func(conn *Conn) { inboundConns <- conn }, - GetNewAddress: func() (net.Addr, error) { + GetNewAddress: func() (*addrmgr.NetAddress, error) { if pauseTargetOutbound.Load() { total := totalPausedAddrs.Add(1) if total == maxFailedAttempts { @@ -1904,7 +1903,7 @@ func TestMaxConnsPerHost(t *testing.T) { OnAccept: func(conn *Conn) { inboundConns <- conn }, - GetNewAddress: func() (net.Addr, error) { + GetNewAddress: func() (*addrmgr.NetAddress, error) { if pauseTargetOutbound.Load() { total := totalPausedAddrs.Add(1) if total == maxFailedAttempts { diff --git a/server.go b/server.go index 327ed87ba..9db24745b 100644 --- a/server.go +++ b/server.go @@ -4258,7 +4258,7 @@ func newServer(ctx context.Context, profiler *profileServer, // to specified peers and actively avoid advertising and connecting to // discovered peers in order to prevent it from becoming a public test // network. - var newAddressFunc func() (net.Addr, error) + var newAddressFunc func() (*addrmgr.NetAddress, error) if !cfg.SimNet && !cfg.RegNet && len(cfg.ConnectPeers) == 0 { filter := func(addrType addrmgr.NetAddressType) bool { switch addrType { @@ -4270,7 +4270,7 @@ func newServer(ctx context.Context, profiler *profileServer, } return false } - newAddressFunc = func() (net.Addr, error) { + newAddressFunc = func() (*addrmgr.NetAddress, error) { for tries := 0; tries < 100; tries++ { addr := s.addrManager.GetAddress(filter) if addr == nil { @@ -4304,7 +4304,7 @@ func newServer(ctx context.Context, profiler *profileServer, continue } - return addrStringToNetAddr(netAddr.Key()) + return netAddr, nil } return nil, errors.New("no valid connect address") From 913f3d7409fb75e4185a960288c4f723191b1623 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:15 -0500 Subject: [PATCH 39/51] connmgr: Add outbound group support to tests. Upcoming changes will be moving the restrictions on the number of automatic connections made to the same network group into the connection manager which will break several of the existing tests. With that in mind, this adds support to the test address generator for generating an address in the next outbound network group and updates the various tests that provide addresses for automatic outbounds to make use of it. An alternative would be to override the limits in every test once they're implemented, but it is generally a better idea to run tests with as few exceptions as possible to provide better coverage. --- internal/connmgr/connmanager_test.go | 79 ++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 16 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 0f5d88d55..5786dfa21 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -149,6 +149,18 @@ func newAddrGenerator(baseAddrPort string) *addrGenerator { } } +// next advances the generator to the next IP and returns the result. It skips +// all addresses of the form "x.x.x.0". +func (g *addrGenerator) next() *addrmgr.NetAddress { + // Skip "x.x.x.0". + g.addr = g.addr.Next() + if g.addr.As4()[3] == 0 { + g.addr = g.addr.Next() + } + + return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) +} + // Next advances the generator to the next IP and returns the result. It skips // all addresses of the form "x.x.x.0". // @@ -158,13 +170,7 @@ func (g *addrGenerator) Next() *addrmgr.NetAddress { g.mtx.Lock() defer g.mtx.Unlock() - // Skip "x.x.x.0". - g.addr = g.addr.Next() - if g.addr.As4()[3] == 0 { - g.addr = g.addr.Next() - } - - return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) + return g.next() } // NextPort advances the generator to the next port of the current IP and @@ -184,6 +190,47 @@ func (g *addrGenerator) NextPort() *addrmgr.NetAddress { return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) } +// nextPrefix advances the generator to the next IP for the given prefix bits +// and returns the result. It skips all addresses of the form "x.x.x.0". +func (g *addrGenerator) nextPrefix(prefixBits uint) *addrmgr.NetAddress { + if prefixBits == 32 { + return g.next() + } + + // Skip "x.x.x.0". + ip := g.addr.As4() + ip32 := binary.BigEndian.Uint32(ip[:]) + if ip32&0xff == 0 { + ip32++ + } + + // Split the IP into network and host bits based on the number of prefix + // bits. + networkMask := ^uint32(0) << (32 - prefixBits) + networkBits := (ip32 & networkMask) + hostBits := ip32 & ^networkMask + + // Calculate the next network. + nextNet := networkBits + (1 << (32 - prefixBits)) + + // Calculate and set the next address. + binary.BigEndian.PutUint32(ip[:], nextNet|hostBits) + g.addr = netip.AddrFrom4(ip) + + return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) +} + +// NextOutboundGroup advances the generator to the next outbound group IP and +// returns the result. It skips all addresses of the form "x.x.x.0". +// +// An outbound group is determined by a certain number of prefix bits. +func (g *addrGenerator) NextOutboundGroup() *addrmgr.NetAddress { + g.mtx.Lock() + defer g.mtx.Unlock() + + return g.nextPrefix(g.outboundGroupPrefixBits) +} + // defaultAddrGenerator returns an address generator configured with a default // starting base address and port useful throughout the tests. The base address // is a normal routable IPv4 address. @@ -928,7 +975,7 @@ func TestTargetOutbound(t *testing.T) { TargetOutbound: targetOutbound, Dial: mockDialer, GetNewAddress: func() (*addrmgr.NetAddress, error) { - return addrGen.Next(), nil + return addrGen.NextOutboundGroup(), nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -956,7 +1003,7 @@ func TestDoubleClose(t *testing.T) { TargetOutbound: 1, Dial: mockDialer, GetNewAddress: func() (*addrmgr.NetAddress, error) { - return addrGen.Next(), nil + return addrGen.NextOutboundGroup(), nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1190,7 +1237,7 @@ func TestNetworkFailure(t *testing.T) { RetryDuration: retryTimeout, Dial: errDialer, GetNewAddress: func() (*addrmgr.NetAddress, error) { - return addrGen.Next(), nil + return addrGen.NextOutboundGroup(), nil }, OnConnection: func(conn *Conn) { t.Fatalf("network failure: got unexpected connection - %v", @@ -1763,7 +1810,7 @@ func TestMaxNormalConns(t *testing.T) { } return nil, errors.New("network down") } - return addrGen.Next(), nil + return addrGen.NextOutboundGroup(), nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1801,7 +1848,7 @@ func TestMaxNormalConns(t *testing.T) { // established. go func() { for range targetManual { - go cm.Connect(ctx, addrGen.Next()) + go cm.Connect(ctx, addrGen.NextOutboundGroup()) } }() manualConns := make([]*Conn, 0, targetManual+1) @@ -1813,7 +1860,7 @@ func TestMaxNormalConns(t *testing.T) { // Ensure manual connections that would exceed the max allowed normal // connections are rejected. - _, err := cm.Connect(ctx, addrGen.Next()) + _, err := cm.Connect(ctx, addrGen.NextOutboundGroup()) if !errors.Is(err, ErrMaxNormalConns) { t.Fatalf("did not reject manual connection at max allowed, err: %v", err) } @@ -1821,7 +1868,7 @@ func TestMaxNormalConns(t *testing.T) { // Ensure inbound connections that would exceed the max allowed normal // connections are rejected. - go listener.Connect(addrGen.Next()) + go listener.Connect(addrGen.NextOutboundGroup()) assertNoConnReceived(t, inboundConns) assertConnManagerInternalState(t, cm) @@ -1843,7 +1890,7 @@ func TestMaxNormalConns(t *testing.T) { // Establish another manual connection to take the place of the target // outbound connection that was just closed and wait for it to be // established. - go cm.Connect(ctx, addrGen.Next()) + go cm.Connect(ctx, addrGen.NextOutboundGroup()) assertConnReceived(t, connected, 0, ConnTypeManual) assertConnManagerInternalState(t, cm) @@ -1857,7 +1904,7 @@ func TestMaxNormalConns(t *testing.T) { // Ensure persistent connections are not subject to the max total normal // connections by adding one and waiting for it to be established. - connID, err := cm.AddPersistent(addrGen.Next()) + connID, err := cm.AddPersistent(addrGen.NextOutboundGroup()) if err != nil { t.Fatalf("failed to add persistent connection: %v", err) } From 07b51bbb9872c517329122c2914900c36bb29433 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:17 -0500 Subject: [PATCH 40/51] connmgr: Limit auto outbounds per group. Similar to other recent PRs in regards to limiting, the current code that limits automatic outbound peers to strongly prefer network groups that are not already connected is enforced by the server. Similar to the previous cases, enforcement after the fact by the server is not ideal for various reasons. One notable limitation with the current code is that it is not possible for it to be aware of pending attempts which significantly raises the probability of connecting to more than one early on before many addresses have been discovered. Performing the limiting and tracking in the connection manager itself allows these things to be properly and cleanly handled since it is aware of all pending and established connections, including manual ones. To make that happen, this implements the logic necessary to enforce a maximum of one outbound peer network group in the connection manager and removes the relevant code for it from the server. In addition to fixing several of the aforementioned potential logic races, this also further strengthens it against bad actors as a part of reimplementing it. The connection manager is now given a unique per-instance cryptographic key that is used when calculating the group keys. This has the effect of making it unpredictable to external observers which in turn protects against the possibility of network-wide poisoning attacks. --- go.mod | 2 +- internal/connmgr/connmanager.go | 262 ++++++++++++++++++++++++++- internal/connmgr/connmanager_test.go | 27 +-- internal/connmgr/csprng.go | 1 + server.go | 69 ++----- 5 files changed, 285 insertions(+), 76 deletions(-) diff --git a/go.mod b/go.mod index ae5853464..e775d64a2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.24.0 require ( github.com/davecgh/go-spew v1.1.1 + github.com/dchest/siphash v1.2.3 github.com/decred/base58 v1.0.6 github.com/decred/dcrd/addrmgr/v4 v4.0.0 github.com/decred/dcrd/bech32 v1.1.4 @@ -49,7 +50,6 @@ require ( decred.org/cspp/v2 v2.4.0 // indirect github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect github.com/companyzero/sntrup4591761 v0.0.0-20220309191932-9e0f3af2f07a // indirect - github.com/dchest/siphash v1.2.3 // indirect github.com/decred/dcrd/dcrec/edwards/v2 v2.0.4 // indirect github.com/decred/dcrd/hdkeychain/v3 v3.1.3 // indirect github.com/golang/snappy v0.0.4 // indirect diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 6c0a042a6..8b09a357b 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "github.com/dchest/siphash" "github.com/decred/dcrd/addrmgr/v4" ) @@ -45,6 +46,11 @@ const ( // base times the number of retries that have been done. defaultMaxRetryDuration = time.Minute * 5 + // defaultMaxPerOutboundGroup is the default maximum number of connections + // per outbound group to strongly prefer when choosing automatic outbound + // addresses. + defaultMaxPerOutboundGroup = 1 + // defaultMaxNormalConns is the default maximum number of normal inbound, // outbound, and pending connections to permit. defaultMaxNormalConns = 125 @@ -243,6 +249,15 @@ type Config struct { // connections in that case. OnAccept func(*Conn) + // DefaultPort specifies the default peer-to-peer port for the active + // network. It is used to make certain policy decisions related to choosing + // suitable addresses. + // + // A value of 0 removes the default port from policy considerations. + // + // Defaults to 0. + DefaultPort uint16 + // MaxNormalConns is the maximum number of normal inbound, outbound, and // pending connections to permit. Defaults to 125. // @@ -290,9 +305,21 @@ type Config struct { // OnDisconnection is a callback that is fired when a connection is closed. OnDisconnection func(*Conn) - // GetNewAddress is a way to get an address to make a network connection - // to. If nil, no new connections will be made automatically. - GetNewAddress func() (*addrmgr.NetAddress, error) + // GetNewAddress is invoked to get an address suitable for making an + // outbound connection along with the last time the address was attempted. + // + // An error for the final return value indicates there are no addresses + // available at all. + // + // The function might be invoked several times to find a suitable address + // prior to attempting any. [Config.Dial] can be used to detect and record + // all attempts. + // + // If nil, no new connections will be made automatically. + // + // If not nil, it is expected to only return valid, routable addresses or an + // error indicating there are no addresses available. + GetNewAddress func() (*addrmgr.NetAddress, time.Time, error) // Dial connects to the address on the named network. Dial func(ctx context.Context, network, addr string) (net.Conn, error) @@ -306,6 +333,130 @@ type Config struct { Whitelists []netip.Prefix } +// outboundGroupInfo houses information related to tracking outbound groups. +// +// It is used to strongly prefer outbound connections to different network +// groups such that it is extremely difficult for attackers to gain control +// of addresses that are a part of a lot of different groups. +// +// This is separate and protected by its own mutex in order to prevent potential +// logic races that could otherwise be induced if it were done via the ordinary +// pending/active connection tracking. +// +// In particular, it is involved in address selection and thus any addresses +// that will ultimately be attempted need to be tracked under the same lock used +// for that selection. +type outboundGroupInfo struct { + // key is a unique cryptographically random seed used when determining + // outbound network group keys. It ensures different connection manager + // instances produce distinct mappings that are unpredictable to external + // observers. + key [2]uint64 + + sync.Mutex + + // These fields are protected by the embedded mutex. + // + // addrs tracks all pending and active addresses (host:port) that have + // entries in counts. + // + // counts provides fast O(1) lookup of the number of pending and active + // outbound addresses per outbound group. It is kept in sync with the addrs + // map. + addrs map[string]uint32 + counts map[uint64]uint32 +} + +// newOutboundGroupInfo returns an initialized outboundGroupInfo instance using +// the provided CSPRNG to generate a key. +func newOutboundGroupInfo(csprng csprng) *outboundGroupInfo { + return &outboundGroupInfo{ + key: [2]uint64{csprng.Uint64(), csprng.Uint64()}, + addrs: make(map[string]uint32), + counts: make(map[uint64]uint32), + } +} + +// GroupKey returns a key that represents the outbound network group for the +// address. +// +// Addresses are assigned to network groups such that it is extremely difficult +// for attackers to gain control of addresses that are a part of a lot of +// different groups. For example, IPv4 networks use the /16 prefix, so all +// addresses in an attacker-controlled subnet or ISP are assigned the same +// group. Other networks, such as IPv6 and Tor use similarly appropriate values +// for the respective networks. +// +// This function is safe for concurrent access. +func (g *outboundGroupInfo) GroupKey(addr *addrmgr.NetAddress) uint64 { + return siphash.Hash(g.key[0], g.key[1], []byte(addr.GroupKey())) +} + +// addAddr adds information about an address to the local state. This is +// expected to be invoked when an eligible outbound address will be dialed. +// +// This function MUST be called with the embedded mutex held (writes). +func (g *outboundGroupInfo) addAddr(addr *addrmgr.NetAddress) { + g.addrs[addr.String()]++ + g.counts[g.GroupKey(addr)]++ +} + +// AddAddr adds information about an address to the local state. This is +// expected to be invoked when an outbound address will be dialed. +// +// This function is safe for concurrent access. +func (g *outboundGroupInfo) AddAddr(addr *addrmgr.NetAddress) { + g.Lock() + g.addAddr(addr) + g.Unlock() +} + +// removeAddr removes information about an address from the local state. This +// is expected to be invoked when an outbound address that was previously added +// is no longer in use (e.g. a dial failed or a non-persistent connection +// associated with the previous addition is closed). +// +// This function MUST be called with the embedded mutex held (writes). +func (g *outboundGroupInfo) removeAddr(addr *addrmgr.NetAddress) { + // The entry might have already been removed by [ConnManager.Disconnect] or + // [ConnManager.Remove]. + addrStr := addr.String() + if _, ok := g.addrs[addrStr]; !ok { + return + } + + g.addrs[addrStr]-- + if g.addrs[addrStr] == 0 { + delete(g.addrs, addrStr) + } + groupKey := g.GroupKey(addr) + g.counts[groupKey]-- + if g.counts[groupKey] == 0 { + delete(g.counts, groupKey) + } +} + +// RemoveAddr removes information about an address from the local state. This +// is expected to be invoked when an outbound address that was previously added +// is no longer in use (e.g. a dial failed or a non-persistent connection +// associated with the previous addition is closed). +// +// This function is safe for concurrent access. +func (g *outboundGroupInfo) RemoveAddr(addr *addrmgr.NetAddress) { + g.Lock() + g.removeAddr(addr) + g.Unlock() +} + +// groupCount returns the number of actively tracked addresses in the same +// outbound group as the provided address. +// +// This function MUST be called with the embedded mutex held (reads). +func (g *outboundGroupInfo) groupCount(addr *addrmgr.NetAddress) uint32 { + groupKey := g.GroupKey(addr) + return g.counts[groupKey] +} + // ConnManager provides a manager to handle network connections. type ConnManager struct { // nextConnID is used to assign unique connection request IDs. @@ -334,6 +485,10 @@ type ConnManager struct { // is guaranteed not to overflow. maxRetryScalingBits uint8 + // maxPerOutboundGroup is the maximum number of connections per outbound + // group to strongly prefer when choosing automatic outbound addresses. + maxPerOutboundGroup uint32 + // runPersistentChan is used to signal the persistent connections handler to // launch a goroutine that attempts to always maintain an established // connection with a given address. @@ -352,6 +507,13 @@ type ConnManager struct { totalNormalConnsSem semaphore activeOutboundsSem semaphore + // outboundGroups tracks outbound address group information. + // + // It is used to strongly prefer outbound connections to different network + // groups such that it is extremely difficult for attackers to gain control + // of addresses that are a part of a lot of different groups. + outboundGroups *outboundGroupInfo + // ****************************************************************** // The fields below this point are protected by the connection mutex. // ****************************************************************** @@ -526,6 +688,7 @@ func (cm *ConnManager) addPersistentEntry(entry *persistentEntry) { cm.persistent[entry.id] = entry cm.connIDByAddr[entry.addr.String()] = entry.id cm.perHostCounts[addrHostKey(entry.addr)]++ + cm.outboundGroups.AddAddr(entry.addr) } // removePersistentEntry removes a persistent connection entry from the local @@ -540,6 +703,7 @@ func (cm *ConnManager) removePersistentEntry(entry *persistentEntry) { delete(cm.connIDByAddr, entry.addr.String()) cm.decrementPerHostCount(addrHostKey(entry.addr)) } + cm.outboundGroups.RemoveAddr(entry.addr) } // rejectConnectedAddr returns an error if there is already either an @@ -873,7 +1037,12 @@ func (cm *ConnManager) Connect(ctx context.Context, addr net.Addr) (*Conn, error pickNoun(maxAllowed, "connection", "connections")) return nil, MakeError(ErrMaxNormalConns, str) } - onClose := cm.totalNormalConnsSem.Release + + cm.outboundGroups.AddAddr(rAddr) + onClose := func() { + cm.outboundGroups.RemoveAddr(rAddr) + cm.totalNormalConnsSem.Release() + } conn, err := cm.dial(ctx, rAddr, ConnTypeManual, onClose, nil) if err != nil { return nil, err @@ -895,9 +1064,13 @@ func (cm *ConnManager) Disconnect(id uint64) error { // any connections that are already in progress and later succeed are // ignored. cm.connMtx.Lock() + _, isPersistent := cm.persistent[id] if info, ok := cm.pending[id]; ok { info.cancel() cm.removePendingInfo(info) + if !isPersistent { + cm.outboundGroups.RemoveAddr(info.addr) + } cm.connMtx.Unlock() return nil } @@ -908,7 +1081,6 @@ func (cm *ConnManager) Disconnect(id uint64) error { conn.Close() // Close requires the conn mutex. return nil } - _, isPersistent := cm.persistent[id] cm.connMtx.Unlock() // Not found in active or pending, but it might still be a persistent conn @@ -953,6 +1125,9 @@ func (cm *ConnManager) Remove(id uint64) error { if info, ok := cm.pending[id]; ok { info.cancel() cm.removePendingInfo(info) + if !isPersistent { + cm.outboundGroups.RemoveAddr(info.addr) + } cm.connMtx.Unlock() return nil } @@ -1338,14 +1513,81 @@ func (cm *ConnManager) persistentConnsHandler(ctx context.Context) { } } +// errNoSuitableAddr indicates no suitable address was found within the allowed +// attempts. +var errNoSuitableAddr = errors.New("no suitable outbound address") + // pickOutboundAddr returns an address suitable for establishing a new outbound // connection. // -// It simply delegates to [Config.GetNewAddress] for now. +// It calls [Config.GetNewAddress] repeatedly (up to a small limit) and applies +// several heuristics to avoid recently attempted addresses, nondefault ports, +// and addresses in already connected outbound groups. +// +// It returns [errNoSuitableAddr] if no suitable address is found after the +// allowed attempts. +// +// When the error is not nil, the returned address is added to the outbound +// groups and it is the responsibility of the caller to remove it when the +// address is no longer in use. // // This function is safe for concurrent access. func (cm *ConnManager) pickOutboundAddr() (*addrmgr.NetAddress, error) { - return cm.cfg.GetNewAddress() + cm.outboundGroups.Lock() + defer cm.outboundGroups.Unlock() + + const ( + // retries is the number of addrs to request before giving up for now. + retries = 100 + + // skipRecentsUntil is the number of tries to skip recently attempted + // addrs. + skipRecentsUntil = (retries * 3) / 10 + + // skipDefaultPortUntil is the number of tries to skip addrs with + // non-default ports. + skipDefaultPortUntil = retries / 2 + ) + + for tries := range retries { + // An error means no addresses are available. No need to retry for now. + addr, lastTry, err := cm.cfg.GetNewAddress() + if err != nil { + return nil, err + } + + // [Config.GetNewAddress] stipulates the returned address will not be + // invalid or unroutable. Those conditions are not double checked. + + // Skip addresses that already have too many other outbound connections + // in the same network group. + // + // The default maximum allowed by per group is one which means this + // significantly increases attack difficulty. + if cm.outboundGroups.groupCount(addr) >= cm.maxPerOutboundGroup { + continue + } + + // Skip recently attempted addresses unless no suitable address has been + // found for enough tries. + now := time.Now() + if tries < skipRecentsUntil && lastTry.Add(10*time.Minute).After(now) { + continue + } + + // Skip addresses with non-default ports unless no suitable address has + // been found for enough tries. + if defaultPort := cm.cfg.DefaultPort; defaultPort != 0 { + if tries < skipDefaultPortUntil && addr.Port != defaultPort { + continue + } + } + + cm.outboundGroups.addAddr(addr) + return addr, nil + } + + return nil, errNoSuitableAddr } // targetOutboundHandler attempts to automatically maintain the target number of @@ -1417,6 +1659,7 @@ func (cm *ConnManager) targetOutboundHandler(ctx context.Context) { go func(addr *addrmgr.NetAddress) { defer wg.Done() onClose := func() { + cm.outboundGroups.RemoveAddr(addr) cm.totalNormalConnsSem.Release() cm.activeOutboundsSem.Release() } @@ -1531,15 +1774,18 @@ func New(cfg *Config) (*ConnManager, error) { } cfg.TargetOutbound = min(cfg.TargetOutbound, cfg.MaxNormalConns) retryDurationBits := uint8(math.Ceil(math.Log2(float64(cfg.RetryDuration)))) + csprng := globalRand cm := ConnManager{ cfg: *cfg, // Copy so caller can't mutate quit: make(chan struct{}), - csprng: globalRand, + csprng: csprng, maxRetryDuration: defaultMaxRetryDuration, maxRetryScalingBits: 63 - retryDurationBits, + maxPerOutboundGroup: defaultMaxPerOutboundGroup, runPersistentChan: make(chan *persistentEntry, MaxPersistent), totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), + outboundGroups: newOutboundGroupInfo(csprng), persistent: make(map[uint64]*persistentEntry, MaxPersistent), pending: make(map[uint64]*pendingConnInfo), active: make(map[uint64]*Conn, cfg.TargetOutbound), diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 5786dfa21..4d8c2c89c 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -293,6 +293,8 @@ func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { seed := newTestPRNGSeed(t) src := mrand.NewChaCha8(seed) cm.csprng = mrand.New(src) // nolint:gosec + cm.outboundGroups.key[0] = cm.csprng.Uint64() + cm.outboundGroups.key[1] = cm.csprng.Uint64() return cm } @@ -974,8 +976,8 @@ func TestTargetOutbound(t *testing.T) { cm := newTestConnManager(t, &Config{ TargetOutbound: targetOutbound, Dial: mockDialer, - GetNewAddress: func() (*addrmgr.NetAddress, error) { - return addrGen.NextOutboundGroup(), nil + GetNewAddress: func() (*addrmgr.NetAddress, time.Time, error) { + return addrGen.NextOutboundGroup(), time.Time{}, nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1002,8 +1004,8 @@ func TestDoubleClose(t *testing.T) { cm := newTestConnManager(t, &Config{ TargetOutbound: 1, Dial: mockDialer, - GetNewAddress: func() (*addrmgr.NetAddress, error) { - return addrGen.NextOutboundGroup(), nil + GetNewAddress: func() (*addrmgr.NetAddress, time.Time, error) { + return addrGen.NextOutboundGroup(), time.Time{}, nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1236,8 +1238,8 @@ func TestNetworkFailure(t *testing.T) { TargetOutbound: targetOutbound, RetryDuration: retryTimeout, Dial: errDialer, - GetNewAddress: func() (*addrmgr.NetAddress, error) { - return addrGen.NextOutboundGroup(), nil + GetNewAddress: func() (*addrmgr.NetAddress, time.Time, error) { + return addrGen.NextOutboundGroup(), time.Time{}, nil }, OnConnection: func(conn *Conn) { t.Fatalf("network failure: got unexpected connection - %v", @@ -1802,15 +1804,15 @@ func TestMaxNormalConns(t *testing.T) { OnAccept: func(conn *Conn) { inboundConns <- conn }, - GetNewAddress: func() (*addrmgr.NetAddress, error) { + GetNewAddress: func() (*addrmgr.NetAddress, time.Time, error) { if pauseTargetOutbound.Load() { total := totalPausedAddrs.Add(1) if total == maxFailedAttempts { hitMaxFailedAttempts <- struct{}{} } - return nil, errors.New("network down") + return nil, time.Time{}, errors.New("network down") } - return addrGen.NextOutboundGroup(), nil + return addrGen.NextOutboundGroup(), time.Time{}, nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1950,15 +1952,15 @@ func TestMaxConnsPerHost(t *testing.T) { OnAccept: func(conn *Conn) { inboundConns <- conn }, - GetNewAddress: func() (*addrmgr.NetAddress, error) { + GetNewAddress: func() (*addrmgr.NetAddress, time.Time, error) { if pauseTargetOutbound.Load() { total := totalPausedAddrs.Add(1) if total == maxFailedAttempts { close(hitMaxFailedAttempts) } - return nil, errors.New("network down") + return nil, time.Time{}, errors.New("network down") } - return nextSameHost(), nil + return nextSameHost(), time.Time{}, nil }, OnConnection: func(conn *Conn) { connected <- conn @@ -1968,6 +1970,7 @@ func TestMaxConnsPerHost(t *testing.T) { }, }) cm.maxRetryDuration = cm.cfg.RetryDuration + cm.maxPerOutboundGroup = maxConnsPerHost + 2 ctx, _, _ := runConnMgrAsync(t, cm) // Wait for the maximum allowed non-whitelisted per-host automatic outbound diff --git a/internal/connmgr/csprng.go b/internal/connmgr/csprng.go index 4f30b6d5b..2aae55268 100644 --- a/internal/connmgr/csprng.go +++ b/internal/connmgr/csprng.go @@ -14,6 +14,7 @@ import ( // uses. This primarily exists so tests can replace the real implementation // with a deterministic PRNG for reproducibility. type csprng interface { + Uint64() uint64 Uint64N(n uint64) uint64 } diff --git a/server.go b/server.go index 9db24745b..1c6eaed0c 100644 --- a/server.go +++ b/server.go @@ -239,19 +239,16 @@ type peerState struct { outboundPeers map[int32]*serverPeer persistentPeers map[int32]*serverPeer banned map[string]time.Time - outboundGroups map[string]int } // makePeerState returns a peer state instance that is used to maintain the -// state of inbound, persistent, and outbound peers as well as banned peers and -// outbound groups. +// state of inbound, persistent, and outbound peers as well as banned peers. func makePeerState() peerState { return peerState{ inboundPeers: make(map[int32]*serverPeer), persistentPeers: make(map[int32]*serverPeer), outboundPeers: make(map[int32]*serverPeer), banned: make(map[string]time.Time), - outboundGroups: make(map[string]int), } } @@ -2683,7 +2680,6 @@ func (s *server) handleAddPeer(sp *serverPeer) bool { } // The peer is an outbound peer at this point. - state.outboundGroups[sp.remoteAddr.GroupKey()]++ if sp.persistent { state.persistentPeers[sp.ID()] = sp } else { @@ -2724,9 +2720,6 @@ func (s *server) DonePeer(sp *serverPeer) { list = state.outboundPeers } if _, ok := list[sp.ID()]; ok { - if !sp.Inbound() { - state.outboundGroups[sp.remoteAddr.GroupKey()]-- - } delete(list, sp.ID()) srvrLog.Debugf("Removed peer %s", sp) return @@ -2819,15 +2812,6 @@ func (s *server) ConnectedCount() int32 { return numConnected } -// OutboundGroupCount returns the number of peers connected to the given -// outbound group key. -func (s *server) OutboundGroupCount(key string) int { - s.peerState.Lock() - count := s.peerState.outboundGroups[key] - s.peerState.Unlock() - return count -} - // AddBytesSent adds the passed number of bytes to the total bytes sent counter // for the server. It is safe for concurrent access. func (s *server) AddBytesSent(bytesSent uint64) { @@ -3911,6 +3895,12 @@ func newServer(ctx context.Context, profiler *profileServer, listenAddrs []string, db database.DB, utxoDb *leveldb.DB, chainParams *chaincfg.Params, dataDir string) (*server, error) { + defaultP2PPort, err := strconv.ParseUint(chainParams.DefaultPort, 10, 16) + if err != nil { + err = fmt.Errorf("invalid default p2p port in chain params: %w", err) + return nil, err + } + amgr := addrmgr.New(cfg.DataDir) services := defaultServices @@ -4258,7 +4248,7 @@ func newServer(ctx context.Context, profiler *profileServer, // to specified peers and actively avoid advertising and connecting to // discovered peers in order to prevent it from becoming a public test // network. - var newAddressFunc func() (*addrmgr.NetAddress, error) + var newAddressFunc func() (*addrmgr.NetAddress, time.Time, error) if !cfg.SimNet && !cfg.RegNet && len(cfg.ConnectPeers) == 0 { filter := func(addrType addrmgr.NetAddressType) bool { switch addrType { @@ -4270,44 +4260,12 @@ func newServer(ctx context.Context, profiler *profileServer, } return false } - newAddressFunc = func() (*addrmgr.NetAddress, error) { - for tries := 0; tries < 100; tries++ { - addr := s.addrManager.GetAddress(filter) - if addr == nil { - break - } - - // Address will not be invalid, local or unroutable - // because addrmanager rejects those on addition. - // Just check that we don't already have an address - // in the same group so that we are not connecting - // to the same network segment at the expense of - // others. - netAddr := addr.NetAddress() - if s.OutboundGroupCount(netAddr.GroupKey()) != 0 { - continue - } - - // Skip recently attempted nodes until we have - // tried 30 times. - if tries < 30 { - lastAttempt := addr.LastAttempt() - if !lastAttempt.IsZero() && - time.Since(lastAttempt) < 10*time.Minute { - continue - } - } - - // allow nondefault ports after 50 failed tries. - if fmt.Sprintf("%d", netAddr.Port) != - s.chainParams.DefaultPort && tries < 50 { - continue - } - - return netAddr, nil + newAddressFunc = func() (*addrmgr.NetAddress, time.Time, error) { + addr := s.addrManager.GetAddress(filter) + if addr == nil { + return nil, time.Time{}, errors.New("no valid connect address") } - - return nil, errors.New("no valid connect address") + return addr.NetAddress(), addr.LastAttempt(), nil } } @@ -4320,6 +4278,7 @@ func newServer(ctx context.Context, profiler *profileServer, OnAccept: func(conn *connmgr.Conn) { s.inboundPeerConnected(ctx, conn) }, + DefaultPort: uint16(defaultP2PPort), RetryDuration: connectionRetryInterval, MaxNormalConns: uint32(cfg.MaxPeers), MaxConnsPerHost: uint32(cfg.MaxSameIP), From d190342b237f3845b1e4e6962dd8eb9a175a0d98 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Tue, 26 May 2026 20:50:17 -0500 Subject: [PATCH 41/51] connmgr: Add outbound group tests. This adds tests to ensure the new outbound group counting and limiting properly limits outbound connections to one connection per outbound group. It also includes randomized address generation with some address in the same group, some with non-default ports, and some recently attempted or not in order to exercise the retry logic as well. --- internal/connmgr/connmanager_test.go | 152 +++++++++++++++++++++++---- 1 file changed, 130 insertions(+), 22 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 4d8c2c89c..2305521ba 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -298,9 +298,9 @@ func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { return cm } -// assertConnManagerInternalState ensures the internal state of the passed +// assertInternalConnState ensures the internal connection state of the passed // connection manager instance is coherent. -func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { +func assertInternalConnState(t *testing.T, cm *ConnManager) { t.Helper() cm.connMtx.Lock() @@ -355,8 +355,8 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { t.Fatalf("conn ID for addr %s mismatch: %d != %d", addrStr, existingID, id) } - perHostCounts[addrHostKey(entry.addr)]++ connIDByAddr[addrStr] = id + perHostCounts[addrHostKey(entry.addr)]++ } // Assert the addr to conn ID mappings match the values obtained from @@ -374,32 +374,73 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { } } +// assertInternalOutboundGroupState ensures the internal outbound group state of +// the passed connection manager instance is coherent. +func assertInternalOutboundGroupState(t *testing.T, cm *ConnManager) { + t.Helper() + + cm.outboundGroups.Lock() + defer cm.outboundGroups.Unlock() + + // Assert the outbound group counts match a manual tally. + outboundGroupCounts := make(map[uint64]uint32) + for addr, count := range cm.outboundGroups.addrs { + netAddr := mustParseAddrPort(addr) + outboundGroupCounts[cm.outboundGroups.GroupKey(netAddr)] += count + } + if !reflect.DeepEqual(cm.outboundGroups.counts, outboundGroupCounts) { + t.Fatalf("mismatched outbound group count maps\ngot: %v\nwant %v", + cm.outboundGroups.counts, outboundGroupCounts) + } +} + +// assertConnManagerInternalState ensures the internal state of the passed +// connection manager instance is coherent. +func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { + t.Helper() + + assertInternalConnState(t, cm) + assertInternalOutboundGroupState(t, cm) +} + // assertConnManagerCleanShutdown ensures the internal state of the passed // connection manager is fully cleaned up as expected. It must only be called // after [ConnManager.Run] returns. func assertConnManagerCleanShutdown(t *testing.T, cm *ConnManager) { t.Helper() - cm.connMtx.Lock() - defer cm.connMtx.Unlock() + func() { + cm.connMtx.Lock() + defer cm.connMtx.Unlock() - if len(cm.active) != 0 { - t.Fatalf("active map is not empty: %d entries", len(cm.active)) - } - if len(cm.pending) != 0 { - t.Fatalf("pending map is not empty: %d entries", len(cm.pending)) - } - if len(cm.persistent) != 0 { - t.Fatalf("persistent map is not empty: %d entries", len(cm.persistent)) - } - if len(cm.connIDByAddr) != 0 { - t.Fatalf("conn ID by addr map not empty: %d entries", - len(cm.connIDByAddr)) - } - if len(cm.perHostCounts) != 0 { - t.Fatalf("per host counts map not empty: %d entries", - len(cm.perHostCounts)) - } + if count := len(cm.active); count != 0 { + t.Fatalf("active map is not empty: %d entries", count) + } + if count := len(cm.pending); count != 0 { + t.Fatalf("pending map is not empty: %d entries", count) + } + if count := len(cm.persistent); count != 0 { + t.Fatalf("persistent map is not empty: %d entries", count) + } + if count := len(cm.connIDByAddr); count != 0 { + t.Fatalf("conn ID by addr map not empty: %d entries", count) + } + if count := len(cm.perHostCounts); count != 0 { + t.Fatalf("per host counts map not empty: %d entries", count) + } + }() + + func() { + cm.outboundGroups.Lock() + defer cm.outboundGroups.Unlock() + + if count := len(cm.outboundGroups.addrs); count != 0 { + t.Fatalf("outbound group addrs map not empty: %d entries", count) + } + if count := len(cm.outboundGroups.counts); count != 0 { + t.Fatalf("outbound group counts map not empty: %d entries", count) + } + }() } // TestNewConfig tests that new ConnManager config is validated as expected. @@ -2057,3 +2098,70 @@ func TestMaxConnsPerHost(t *testing.T) { assertNoConnReceivedTimeout(t, connected, noConnWaitTimeout) assertConnManagerInternalState(t, cm) } + +// TestOutboundGroups ensures the connection manager limits the automatic +// outbound connections to one connection per outbound group. It includes +// randomized address generation with some addresses in the same group, some +// with non-default ports, and some recently attempted or not to exercise the +// retry logic as well. +func TestOutboundGroups(t *testing.T) { + t.Parallel() + + addrGen := defaultAddrGenerator() + defaultPort := addrGen.port + var cm *ConnManager + randomizedNewAddr := func() (*addrmgr.NetAddress, time.Time, error) { + // Only return a new outbound group 10% of the time. + var addr *addrmgr.NetAddress + if rv := cm.csprng.Uint64N(10); rv < 1 { + addr = addrGen.NextOutboundGroup() + } else { + addr = addrGen.Next() + } + + // Return a random port 50% of the time. + if cm.csprng.Uint64N(10) < 5 { + const minPort = 1025 + addr.Port = uint16(minPort + cm.csprng.Uint64N(1<<16-1-minPort)) + } + + // Return a recent last attempt 30% of the time. + var lastAttempt time.Time + if cm.csprng.Uint64N(10) < 3 { + lastAttempt = time.Now().Add(-20 * time.Second) + } + + return addr, lastAttempt, nil + } + + const targetOutbound = 5 + connected := make(chan *Conn) + cm = newTestConnManager(t, &Config{ + TargetOutbound: targetOutbound, + RetryDuration: 50 * time.Millisecond, + DefaultPort: defaultPort, + Dial: mockDialer, + GetNewAddress: randomizedNewAddr, + OnConnection: func(conn *Conn) { + connected <- conn + }, + }) + cm.maxRetryDuration = cm.cfg.RetryDuration + runConnMgrAsync(t, cm) + + // Wait for the expected number of target outbound conns to be established. + groups := make(map[uint64]struct{}) + outbounds := make([]*Conn, 0, targetOutbound) + for len(outbounds) < targetOutbound { + conn := assertConnReceived(t, connected, 0, ConnTypeOutbound) + outbounds = append(outbounds, conn) + groups[cm.outboundGroups.GroupKey(&conn.remoteAddr)] = struct{}{} + } + assertConnManagerInternalState(t, cm) + + // Ensure only one address per outbound group was selected. + if len(groups) != targetOutbound { + t.Fatalf("unexpected number of outbound groups -- got %d, want %d", + len(groups), targetOutbound) + } +} From 929aa0ffe442e4ccc5c0fc5384ad196396605cbc Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sun, 31 May 2026 23:27:21 -0500 Subject: [PATCH 42/51] connmgr: Update README.md for outbound groups. --- internal/connmgr/README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index a7e171979..9321154ac 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -13,7 +13,7 @@ logic. It handles all general connection lifecycle concerns such as accepting inbound connections, automatically maintaining a set number of outbound connections, -maintaining persistent connections, and limiting max connections. +maintaining persistent connections, enforcing limits, and preventing duplicates. The design has a strong emphasis on reliability, readability, and efficiency under high connection load while also aiming to provide an ergonomic API. @@ -26,6 +26,12 @@ The following is a brief overview of the key features: - Automatic outbound maintenance - Maintains up to `TargetOutbound` normal outbound connections via a provided address source (`GetNewAddress`) + - Strongly prefers connections to different network segments + - Incorporates intelligent address selection + - Skips addresses in already-connected outbound groups + - Skips recently attempted addresses unless no suitable addresses are found + after enough retries + - Prefers default peer-to-peer port addresses (configurable via `DefaultPort`) - Persistent connections - Maintains up to `MaxPersistent` addresses that are automatically retried with exponential backoff and jitter on disconnect From cfb50a71529bb8c897df6cd3492eeb0c58745782 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:17:54 -0500 Subject: [PATCH 43/51] ratelimit: Introduce new internal package. With the primary goal of supported upcoming inbound rate limiting and anti-flood measures, this implements a simple concurrent safe token bucket rate limiter. Since rate limiting is often in the hot path and exposed to extreme conditions, the package is designed to be highly efficient, use minimal memory, and support high concurrency. The API is currently primarily aimed at serving use cases where the intention is to drop events that exceed the rate limit. However, it also provides the minimal information needed for callers to manually implement blocking behavior until the next event is allowed if desired. This only implements the base package functionality. Future commits will add tests, examples, and documentation. --- internal/ratelimit/ratelimit.go | 188 ++++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 internal/ratelimit/ratelimit.go diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 000000000..2dd11bc9f --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,188 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package ratelimit + +import ( + "math" + "sync" + "time" +) + +// Forever represents an infinite duration. +const Forever time.Duration = math.MaxInt64 + +// Limiter provides a simple token bucket rate limiter for controlling the +// frequency of permitted events. The token bucket algorithm used by this +// implementation works by starting with a fixed number of tokens specified by +// the burst size and refills the tokens at the specified rate. Events are +// allowed so long as there is at least one token in the bucket at the time of +// the event. It is ideal for use in traffic shaping and traffic policing. +// +// The limiter is primarily aimed at serving use cases where the intention is to +// drop events that exceed the rate limit by only processing events when +// [Limiter.Allow] reports true and dropping them otherwise. +// +// The current number of tokens in the bucket can be obtained with +// [Limiter.Tokens] and the fixed burst size is available via [Limiter.Burst]. +// +// Callers that wish to block until the next event is allowed instead of merely +// dropping events that exceed the rate can make use of +// [Limiter.UntilNextAllowed] to determine how long they need to wait. +// +// [New] must be used to create a usable limiter since the zero value of this +// struct is not valid. +type Limiter struct { + // These fields do not require a mutex since they are set at initialization + // time and never modified after. + nowFn func() time.Time + + // These fields are protected by the embedded mutex. + // + // rate is the number of events to allow per second. + // + // burst is the maximum amount of tokens to permit for handling rapid + // bursts of events. + // + // tokens is the number of remaining tokens in the bucket. More events are + // permitted while it is greater than 0. + // + // updated is the last time tokens was updated. + mtx sync.Mutex + rate float64 + burst float64 + tokens float64 + updated time.Time +} + +// New returns a token bucket rate limiter that allows events up to the provided +// rate, in number of events per second, while also allowing bursts up to the +// provided burst size. +// +// For example a rate of 10.5 with a burst size of 30 would allow an average of +// 10.5 events per second with periodic bursts of up to 30 events. +// +// In order to rate limit events to every X seconds (versus X events per +// second), specify the rate scaled by 1/X. +// +// Scale the rate accordingly for other time units. +// +// For example, to specify a rate of 15 events per minute, the rate would be 900 +// (15*60) and 450 events every 2 hours would be 450/(2*3600) = 0.0625. +func New(rate float64, burst uint32) *Limiter { + return &Limiter{ + nowFn: time.Now, + rate: rate, + burst: float64(burst), + tokens: float64(burst), + } +} + +// Burst returns the burst size specified when the limiter was created. +// +// This function is safe for concurrent access. +func (l *Limiter) Burst() uint32 { + l.mtx.Lock() + burst := uint32(l.burst) + l.mtx.Unlock() + return burst +} + +// durationToTokens returns the number of tokens that would refill during the +// provided duration at the given rate. +func durationToTokens(d time.Duration, rate float64) float64 { + if rate <= 0 { + return 0 + } + return d.Seconds() * rate +} + +// tokensAt returns the number of available tokens at the provided time. Times +// prior to the last time the limiter was updated will return the current number +// of tokens in the bucket. +// +// This function MUST be called with the embedded mutex held (for reads). +func (l *Limiter) tokensAt(t time.Time) float64 { + updated := l.updated + if t.Before(updated) { + updated = t + } + elapsed := t.Sub(updated) + delta := durationToTokens(elapsed, l.rate) + numTokens := l.tokens + delta + if numTokens > l.burst { + numTokens = l.burst + } + return numTokens +} + +// Tokens returns the number of available tokens at the current time. +// +// This function is safe for concurrent access. +func (l *Limiter) Tokens() float64 { + l.mtx.Lock() + tokens := l.tokensAt(l.nowFn()) + l.mtx.Unlock() + return tokens +} + +// Allow returns true when an event is allowed to happen at the current time +// and, when it returns true, the state is updated to consume a token. +// +// Callers will typically want to process the event normally when true is +// returned and drop it otherwise. +func (l *Limiter) Allow() bool { + l.mtx.Lock() + defer l.mtx.Unlock() + + now := l.nowFn() + tokens := l.tokensAt(now) + tokens -= float64(1) + + if tokens >= 0 && tokens <= l.burst { + l.updated = now + l.tokens = tokens + return true + } + return false +} + +// tokensToDuration returns the duration the provided number of tokens would +// take to refill at the given rate. +func tokensToDuration(tokens float64, rate float64) time.Duration { + if rate <= 0 { + return Forever + } + duration := (tokens / rate) * float64(time.Second) + if uint64(duration) > math.MaxInt64 { + return Forever + } + return time.Duration(duration) +} + +// UntilNextAllowed returns the duration that must elapse until the next event +// is allowed. [Forever] is returned when no more events will ever be allowed, +// such as when there is a negative rate or a burst size of 0. +func (l *Limiter) UntilNextAllowed() time.Duration { + // Events are never allowed with a burst size of 0. + l.mtx.Lock() + if l.burst == 0 { + l.mtx.Unlock() + return Forever + } + tokens := l.tokensAt(l.nowFn()) + rate := l.rate + l.mtx.Unlock() + + // The next event is not allowed until there is at least one token, so + // determine how much is needed to reach one. + needed := 1 - tokens + if needed <= 0 { + // There is already one or more tokens available. + return 0 + } + + // Convert the needed tokens into a duration based on the rate. + return tokensToDuration(needed, rate) +} From eae7af1390fbfed2ff6dafb5d1580c7cd4f45a0c Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:17:55 -0500 Subject: [PATCH 44/51] ratelimit: Add allow and max duration tests. This adds tests to ensure the functionality related to the Allow method works as intended including the token regeneration rate, burst handling, allowed reporting, and time until next allowed. It also adds additional tests to ensure pathological conditions that would exceed the max allowed time.Duration are clamped as expected. --- internal/ratelimit/ratelimit_test.go | 240 +++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 internal/ratelimit/ratelimit_test.go diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 000000000..ea42b4d68 --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,240 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package ratelimit + +import ( + "math" + "testing" + "time" +) + +// TestLimiter ensures functionality related to [Limiter.Allow] works as +// intended such as the token regeneration rate, burst handling, allowed +// reporting and time until next allowed. It also includes tests for some +// corner cases such as negative rates and backwards time jumps. +func TestLimiter(t *testing.T) { + t.Parallel() + + // Create a mock time to control testing. + startTime := time.Now() + mockNow := startTime + mockNowFn := func() time.Time { return mockNow } + + // asSec returns the provided number of seconds as a duration. + asSec := func(secs int) time.Duration { + return time.Second * time.Duration(secs) + } + + // asMs returns the provided number of milliseconds as a duration. + asMs := func(millis int) time.Duration { + return time.Millisecond * time.Duration(millis) + } + + type perLimiterTest struct { + off time.Duration // offset from start time to test + tokens float64 // expected number of tokens at time + next time.Duration // expected duration until next event is allowed + allowed bool // expected allow value at time to test + } + + tests := []struct { + name string // test description + rate float64 // rate to use + burst uint32 // burst size to use + perLimiterTests []perLimiterTest // tests to run against limiter + }{{ + // Test burst capacity and regen rate with 1 event per sec and burst + // size 5. + name: "1/s, burst 5", + rate: 1.0, + burst: 5, + perLimiterTests: []perLimiterTest{ + // Bucket starts full and allows up to burst rate. + {off: 0, tokens: 5, next: 0, allowed: true}, + {off: 0, tokens: 4, next: 0, allowed: true}, + {off: 0, tokens: 3, next: 0, allowed: true}, + {off: 0, tokens: 2, next: 0, allowed: true}, + {off: 0, tokens: 1, next: 0, allowed: true}, + // Tokens exhausted. + {off: 0, tokens: 0, next: asSec(1), allowed: false}, + // Bucket refills at 1 per sec. + {off: asSec(1), tokens: 1, next: 0, allowed: true}, + {off: asSec(2), tokens: 1, next: 0, allowed: true}, + {off: asSec(2), tokens: 0, next: asSec(1), allowed: false}, + // Back to full. + {off: asSec(7), tokens: 5, next: 0, allowed: true}, + // Doesn't fill back up more than burst. + {off: asSec(9), tokens: 5, next: 0, allowed: true}, + }, + }, { + // Test burst capacity and regen rate with 1 event per sec and burst + // size 1. + name: "1/s, burst 1", + rate: 1.0, + burst: 1, + perLimiterTests: []perLimiterTest{ + // Bucket starts full and allows up to burst rate. + {off: 0, tokens: 1, next: 0, allowed: true}, + // Tokens exhausted. + {off: 0, tokens: 0, next: asSec(1), allowed: false}, + // Bucket refills at 1 per sec. + {off: asSec(1), tokens: 1, next: 0, allowed: true}, + {off: asSec(2), tokens: 1, next: 0, allowed: true}, + {off: asSec(2), tokens: 0, next: asSec(1), allowed: false}, + // Back to full. + {off: asSec(3), tokens: 1, next: 0, allowed: true}, + // Doesn't fill back up more than burst. + {off: asSec(5), tokens: 1, next: 0, allowed: true}, + }, + }, { + // Test backwards time jumps and regen rate with 100 events per sec and + // burst size 6. Thus the refill rate is 1 event per 10ms. + name: "100/s, burst 6", + rate: 100.0, + burst: 6, + perLimiterTests: []perLimiterTest{ + // Start one second in the future, back one second to the starting + // time and consume all tokens. + {off: asSec(1), tokens: 6, next: 0, allowed: true}, + {off: 0, tokens: 5, next: 0, allowed: true}, + {off: 0, tokens: 4, next: 0, allowed: true}, + {off: 0, tokens: 3, next: 0, allowed: true}, + {off: 0, tokens: 2, next: 0, allowed: true}, + {off: 0, tokens: 1, next: 0, allowed: true}, + // Tokens exhausted. + {off: 0, tokens: 0, next: asMs(10), allowed: false}, + {off: 0, tokens: 0, next: asMs(10), allowed: false}, + // Bucket refills at 1 per 10ms. + {off: asMs(10), tokens: 1, next: 0, allowed: true}, + {off: asMs(10), tokens: 0, next: asMs(10), allowed: false}, + {off: asMs(20), tokens: 1, next: 0, allowed: true}, + // Back to full. + {off: asMs(80), tokens: 6, next: 0, allowed: true}, + // Doesn't fill back up more than burst. + {off: asMs(100), tokens: 6, next: 0, allowed: true}, + }, + }, { + // Test burst capacity with 1 event per 10 secs and burst size 10. + name: "1 per 10s, burst 10", + rate: 0.1, + burst: 10, + perLimiterTests: []perLimiterTest{ + {off: 0, tokens: 10, next: 0, allowed: true}, + {off: 0, tokens: 9, next: 0, allowed: true}, + {off: 0, tokens: 8, next: 0, allowed: true}, + {off: 0, tokens: 7, next: 0, allowed: true}, + {off: 0, tokens: 6, next: 0, allowed: true}, + {off: 0, tokens: 5, next: 0, allowed: true}, + {off: 0, tokens: 4, next: 0, allowed: true}, + {off: 0, tokens: 3, next: 0, allowed: true}, + {off: 0, tokens: 2, next: 0, allowed: true}, + {off: 0, tokens: 1, next: 0, allowed: true}, + // Tokens exhausted. + {off: 0, tokens: 0, next: asSec(10), allowed: false}, + {off: 0, tokens: 0, next: asSec(10), allowed: false}, + // Bucket refills at 1 per 10sec. + {off: asSec(1), tokens: 0.1, next: asSec(9), allowed: false}, + {off: asSec(5), tokens: 0.5, next: asSec(5), allowed: false}, + {off: asSec(10), tokens: 1, next: 0, allowed: true}, + {off: asSec(10), tokens: 0, next: asSec(10), allowed: false}, + // Back to full. + {off: asSec(110), tokens: 10, next: 0, allowed: true}, + // Doesn't fill back up more than burst. + {off: asSec(130), tokens: 10, next: 0, allowed: true}, + }, + }, { + // Test negative rates do not regenerate tokens. + name: "negative rate, burst 3", + rate: -1.0, + burst: 3, + perLimiterTests: []perLimiterTest{ + {off: 0, tokens: 3, next: 0, allowed: true}, + {off: 0, tokens: 2, next: 0, allowed: true}, + {off: 0, tokens: 1, next: 0, allowed: true}, + // Tokens exhausted and do not regenerate. + {off: 0, tokens: 0, next: Forever, allowed: false}, + {off: 0, tokens: 0, next: Forever, allowed: false}, + }, + }, { + // Test burst size of 0 never allows any events. + name: "1/s, burst 0", + rate: 1.0, + burst: 0, + perLimiterTests: []perLimiterTest{ + {off: 0, tokens: 0, next: Forever, allowed: false}, + {off: 0, tokens: 0, next: Forever, allowed: false}, + }, + }} + + for _, test := range tests { + // Create limiter with the rate and burst values specified by the test. + // Also override the now function so the cur time can be manipulated. + limiter := New(test.rate, test.burst) + limiter.nowFn = mockNowFn + + // Ensure burst rate is returned properly. + if gotBurst := limiter.Burst(); gotBurst != test.burst { + t.Errorf("%q: mismatched burst -- got %v, want %v", test.name, + gotBurst, test.burst) + } + + for i, plTest := range test.perLimiterTests { + // Ensure the expected number of tokens are reported. + mockNow = startTime.Add(plTest.off) + gotTokens := limiter.Tokens() + if gotTokens != plTest.tokens { + t.Errorf("%q-%d: mismatched tokens -- got %v, want %v", + test.name, i, gotTokens, plTest.tokens) + continue + } + + // Ensure the expected duration until the next allowed event is + // reported. + gotNextAllowed := limiter.UntilNextAllowed() + if gotNextAllowed != plTest.next { + t.Errorf("%q-%d: mismatched next allowed -- got %v, want %v", + test.name, i, gotNextAllowed, plTest.next) + continue + } + + // Ensure the expected allowed status is reported. + gotAllowed := limiter.Allow() + if gotAllowed != plTest.allowed { + t.Errorf("%q-%d: mismatched allowed -- got %v, want %v", + test.name, i, gotAllowed, plTest.allowed) + continue + } + } + } +} + +// TestLimiterMaxDuration ensures any calculated durations that would exceed the +// max allowed value by [time.Duration] are clamped to [Forever]. +func TestLimiterMaxDuration(t *testing.T) { + // Create a mock time to control testing. + mockNow := time.Now() + mockNowFn := func() time.Time { return mockNow } + + // Since [time.Duration] is an int64 and is in nanoseconds, the maximum + // duration that can be represented is 9223372036854775807 nanoseconds. + // + // Calculate the equivalent rate such that the time until the next allowed + // event would exceed that and ensure it is properly clamped to [Forever]. + limiter := New(float64(time.Second)/float64(math.MaxInt64), 1) + limiter.nowFn = mockNowFn + if !limiter.Allow() { + t.Fatalf("burst size of 1 should allow the event") + } + if gotDuration := limiter.UntilNextAllowed(); gotDuration != Forever { + t.Fatalf("unexpected duration -- got %v, want %v", gotDuration, Forever) + } + + // Ensure [Forever] is no longer returned once the duration no longer + // exceeds the max after it previously did. + mockNow = mockNow.Add(time.Millisecond) + if gotDuration := limiter.UntilNextAllowed(); gotDuration == Forever { + t.Fatalf("unexpected duration -- got %v, want < %v", gotDuration, Forever) + } +} From 09264c4c2690d0addd81dfbb9f81ab4d18f510fc Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:17:56 -0500 Subject: [PATCH 45/51] ratelimit: Add allow example. This adds a basic example of using the rate limiter. --- internal/ratelimit/example_test.go | 52 ++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 internal/ratelimit/example_test.go diff --git a/internal/ratelimit/example_test.go b/internal/ratelimit/example_test.go new file mode 100644 index 000000000..d31359406 --- /dev/null +++ b/internal/ratelimit/example_test.go @@ -0,0 +1,52 @@ +// Copyright (c) 2026 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package ratelimit_test + +import ( + "fmt" + "math" + "time" + + "github.com/decred/dcrd/internal/ratelimit" +) + +// This example demonstrates creating and using a rate limiter that allows 10 +// events per second with periodic bursts of up to 25 events. +func ExampleLimiter_Allow() { + const eventsPerSec = 10 + const burstTokens = 25 + limiter := ratelimit.New(eventsPerSec, burstTokens) + + // Simulate a burst of events. Ordinarily the if statement would be in + // response to events that are externally generated, but the events are + // simulated here with a loop for the purposes of the example. + var numAllowed, numDropped uint64 + for range burstTokens + 10 { + if limiter.Allow() { + // Ordinarily this would process the allowed event. + numAllowed++ + } else { + numDropped++ + } + } + fmt.Printf("num events allowed: %v\n", numAllowed) + fmt.Printf("num events dropped: %v\n", numDropped) + + // There will be no more tokens available since another event will not be + // allowed for another 100 milliseconds (minus however long has already + // elapsed since the final allowed event and reaching this point) given the + // rate is 10 per second and the burst size has been exhausted. + fmt.Printf("num tokens remaining: %v\n", math.Floor(limiter.Tokens())) + const msPerEvent = time.Duration(float64(1000)/eventsPerSec) * time.Millisecond + time.Sleep(msPerEvent) + isAllowed := limiter.UntilNextAllowed() == 0 + fmt.Printf("event allowed after %v: %v\n", msPerEvent, isAllowed) + + // Output: + // num events allowed: 25 + // num events dropped: 10 + // num tokens remaining: 0 + // event allowed after 100ms: true +} From ea7b0dc111aaa9c1ebd411d8524f0f9e141f639b Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:17:57 -0500 Subject: [PATCH 46/51] ratelimit: Add README.md. --- internal/ratelimit/README.md | 70 ++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 internal/ratelimit/README.md diff --git a/internal/ratelimit/README.md b/internal/ratelimit/README.md new file mode 100644 index 000000000..2f7dc4732 --- /dev/null +++ b/internal/ratelimit/README.md @@ -0,0 +1,70 @@ +ratelimit +========= + +[![Build Status](https://github.com/decred/dcrd/workflows/Build%20and%20Test/badge.svg)](https://github.com/decred/dcrd/actions) +[![ISC License](https://img.shields.io/badge/license-ISC-blue.svg)](http://copyfree.org) +[![Doc](https://img.shields.io/badge/doc-reference-blue.svg)](https://pkg.go.dev/github.com/decred/dcrd/internal/ratelimit) + +## Overview + +Package `ratelimit` implements a simple concurrent safe token bucket rate +limiter. + +It is ideal for traffic shaping, rate limiting API calls, and controlling +resource usage while supporting controlled bursts of activity. + +Since rate limiting is often in the hot path and exposed to extreme conditions, +the package is designed to be highly efficient, use minimal memory, and support +high concurrency. + +The API is currently primarily aimed at serving use cases where the intention is +to drop events that exceed the rate limit. However, it also provides the minimal +information needed for callers to manually implement blocking behavior until the +next event is allowed if desired. + +Comprehensive tests are included to ensure proper functionality. + +All dcrd code in the main module are expected to use this package over +golang.org/x/time/rate. This package is more efficient, tailored specifically +to the needs of dcrd, and it avoids an extra dependency on the x packages that +have a release policy that conflicts with dcrd's release policy. + +## Creating a Rate Limiter + +Use `New` to create a limiter with a desired rate (events per second) and +burst size. + +For example, a rate of 10.5 with a burst size of 30 would allow an average of +10.5 events per second with periodic bursts of up to 30 events. + +In order to rate limit events to every `x` seconds (versus `x` events per +second), specify the rate scaled by `1/x` and scale the rate accordingly for +other time units. + +For example, to specify a rate of 15 events per minute, the rate would be 900 +`(15*60)` and 450 events every 2 hours would be 0.0625 (`450/(2*3600)`). + +## Using the Limiter + +Call `Allow` to determine whether an event is permitted at the current time. It +returns `true` when an event may proceed and automatically consumes a token. + +For use cases where blocking until the next event is allowed is preferred +instead of merely dropping events, use `UntilNextAllowed` to determine how long +to wait. + +## Querying State + +The current number of tokens in the bucket can be obtained with `Tokens` and the +fixed maximum burst size is available via `Burst`. + +## Examples + +* [Basic Limiter Usage](https://pkg.go.dev/github.com/decred/dcrd/internal/ratelimit#example-Limiter.Allow) + Demonstrates creating and using a rate limiter that allows up to 10 events + per second with periodic bursts of up to 25 events. + +## License + +Package ratelimit is licensed under the [copyfree](http://copyfree.org) ISC +License. From 82ec86f558a34de1d6842f7269b27879bf9aaf90 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:17:58 -0500 Subject: [PATCH 47/51] connmgr: Add inbound rate limiting. The current code leans heavily on sysadmins to provide traffic shaping via their OS firewalls. While the OS is typically an ideal place to implement the vast majority of such policies, most users are not sysadmins and are very unlikely to have strong knowledge of what constitute good traffic patterns and general connection needs of the network. Allowing dcrd to handle the vast majority of it with a reasonable default configuration will alleviate the need of any additional OS-level configuration for the vast majority of users and cases. It also means that dcrd can make more intelligent decisions to dynamically adjust behavior based on prevailing network conditions since it is aware of the protocol. Finally, implementing it here does not prevent more sophisticated sysadmins from enforcing additional policies and QoS with their OS firewalls. As a first layer of defense, this implements token bucket rate limiting on a per network group basis. In order to ensure memory remains tightly controlled, an LRU with a 20 minute TTL that enforces a max of 10000 independent rate limiters is used. Each network group is rate limited to an average of 1 connection per 5 seconds and allows periodic controlled bursts up to 3. For IPv4, the network group is set to individual IPv4s. For IPv6, it is set to the /64 prefix that is typically assigned by residential ISPs. In other words, it has the effect of treating all IPv6s in the typical block residential uses are assigned as a single entity since they are the easiest for the same attacker to control. The connection manager is now given a unique per-instance cryptographic key for calculating the inbound group keys. This has the effect of making it unpredictable to external observers which in turn protects against the possibility of network-wide cache poisoning attacks. It also introduces rate limiting on the logging of dropped connections to ensure the drop logging does not become its own DoS vector. The logging now includes a summary of the number of dropped connections if any were suppressed due to the rate limiting. --- internal/connmgr/connmanager.go | 239 ++++++++++++++++++++++++++- internal/connmgr/connmanager_test.go | 3 + 2 files changed, 236 insertions(+), 6 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 8b09a357b..bc1f7f25e 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -21,6 +21,9 @@ import ( "github.com/dchest/siphash" "github.com/decred/dcrd/addrmgr/v4" + "github.com/decred/dcrd/container/lru" + "github.com/decred/dcrd/internal/ratelimit" + "github.com/decred/slog" ) const ( @@ -63,6 +66,41 @@ const ( // defaultTargetOutbound is the default number of outbound connections to // maintain. defaultTargetOutbound = 8 + + // ********************************************************************* + // Constants related to rate limiting inbound connections and associated + // logging. + // ********************************************************************* + + // maxGroupLimiters specifies the maximum number of inbound group limiters + // to cache and is set to target a reasonable balance between memory usage + // and the number of addresses/network groups simultaneously subject to + // direct rate limiting. + // + // maxPerGroupTTL is the time to keep each inbound rate limiter in the cache + // without access before they expire. + // + // These values result in ~2 MiB memory usage including overhead for normal + // operation and a temporary maximum of ~6.5 MiB under sustained worst case + // attack scenarios. + maxGroupLimiters = 10000 + maxPerGroupTTL = 20 * time.Minute + + // groupRateLimit and groupBurstLimit control the inbound rate limiting per + // network group. + // + // These values result in rate limiting each group to an average of one + // connection per five seconds with periodic bursts up to 3. + groupRateLimit = 0.2 + groupBurstLimit = 3 + + // dropLogRateLimit and dropLogBurstLimit define how often dropped + // connections are allowed to be logged before suppression. + // + // These values result in only allowing an average of 1 dropped connection + // per minute to be logged with periodic bursts up to 4. + dropLogRateLimit = float64(1) / 60 + dropLogBurstLimit = 4 ) // ConnectionType specifies the different types of supported connections. @@ -514,6 +552,10 @@ type ConnManager struct { // of addresses that are a part of a lot of different groups. outboundGroups *outboundGroupInfo + // inboundLimiter tracks information about inbound connections and provides + // per group rate limiting. + inboundLimiter *inboundRateLimiter + // ****************************************************************** // The fields below this point are protected by the connection mutex. // ****************************************************************** @@ -1172,6 +1214,181 @@ func inboundStdlibNetAddrToAddrMgrAddr(addr net.Addr) (*addrmgr.NetAddress, erro return stdlibNetAddrToAddrMgrNetAddr(addr) } +// inboundGroupKey represents an inbound network group to use when rate +// limiting addresses. See [inboundRateLimiter.GroupKey]. +type inboundGroupKey struct { + hash0 uint64 + hash1 uint64 +} + +// inboundRateLimiter houses state related to rate limiting inbound connections. +type inboundRateLimiter struct { + // burstLimit is the max burst size for the group rate limiters. It is set + // to [groupBurstLimit] by default. + burstLimit uint32 + + // key is a unique cryptographically random seed used when determining + // inbound network group keys. It ensures different connection manager + // instances produce distinct mappings that are unpredictable to external + // observers. + key [2]uint64 + + // groupLimiters provides distinct rate limiters per inbound group up to the + // max capacity of the LRU. See [inboundRateLimiter.GroupKey] for details + // on how groups are determined. + groupMtx sync.Mutex + groupLimiters *lru.Map[inboundGroupKey, *ratelimit.Limiter] + + // These fields are protected by the log mutex. + // + // logLimiter provides rate limiting for logging of dropped inbound + // connections. It is used in conjunction with [droppedLogs], so even + // though it has its own mutex, it typically will also need to be protected + // by the embedded mutex. + // + // droppedLogs tallies the number of dropped inbound connections during log + // suppression due to exceeding the logging rate. + logMtx sync.Mutex + logLimiter *ratelimit.Limiter + droppedLogs uint64 +} + +// newInboundRateLimiter returns an initialized inboundRateLimiter instance. +func newInboundRateLimiter(csprng csprng) *inboundRateLimiter { + newGroupLims := lru.NewMapWithDefaultTTL[inboundGroupKey, *ratelimit.Limiter] + return &inboundRateLimiter{ + key: [2]uint64{csprng.Uint64(), csprng.Uint64()}, + burstLimit: groupBurstLimit, + groupLimiters: newGroupLims(maxGroupLimiters, maxPerGroupTTL), + logLimiter: ratelimit.New(dropLogRateLimit, dropLogBurstLimit), + } +} + +// GroupKey returns a key that represents an inbound network group to use when +// rate limiting the provided address. +// +// This should not be confused with the outbound group key. They are not the +// same and serve different purposes. +// +// The group for IPv4 is the entire address (/32 prefix) and the typical +// residential block for IPv6 (/64 prefix). +// +// For IPv4, that has the effect of rate limiting individual addresses. +// +// For IPv6, it has the effect of rate limiting all addresses in the typical +// residential blocks assigned by ISPs as a single entity since they are the +// easiest for the same attacker to control. +// +// This function is safe for concurrent access. +func (l *inboundRateLimiter) GroupKey(addr *addrmgr.NetAddress) inboundGroupKey { + var preimage []byte + switch addr.Type { + case addrmgr.IPv4Address: + const bits = 32 + ip, _ := netip.AddrFromSlice(addr.IP) + prefix, _ := ip.Prefix(bits) + prefixBytes := prefix.Addr().As4() + preimage = prefixBytes[:] + + case addrmgr.IPv6Address: + const bits = 64 + ip, _ := netip.AddrFromSlice(addr.IP) + prefix, _ := ip.Prefix(bits) + prefixBytes := prefix.Addr().As16() + preimage = prefixBytes[:] + + case addrmgr.TorV3Address: + // Remote addresses for inbound connections are never Tor addresses, but + // be safe and treat them all as a single group anyway. + preimage = []byte("tor") + + case addrmgr.UnknownAddressType: + fallthrough + default: + // Group all unknown or future address types together for safety, but + // this should never be hit in practice. + preimage = []byte("unknown") + } + + h0, h1 := siphash.Hash128(l.key[0], l.key[1], preimage) + return inboundGroupKey{h0, h1} +} + +// Allow updates the limiter state for the given address and returns whether an +// inbound connection from it is permitted at the current time. +// +// It enforces a per group (prefix based) rate limit. The connection is allowed +// when that rate limit has not been exceeded. +// +// Care must be taken when modifying this method. It is in the critical hot +// path for every inbound connection and must remain fast and tightly control +// memory usage in order to remain resilient under sustained misbehavior. +// +// This function is safe for concurrent access. +func (l *inboundRateLimiter) Allow(addr *addrmgr.NetAddress) bool { + // Rate limit the inbound group. + // + // Either get an existing rate limiter or create a new one when one does not + // already exist. Then put the limiter into the LRU cache unconditionally + // so its TTL is updated. + // + // Adding a new entry may evict another limiter when at max capacity. In + // practice, that case is only realistically possible to hit when under + // a heavy DDoS attack. + groupKey := l.GroupKey(addr) + l.groupMtx.Lock() + limiter, ok := l.groupLimiters.Get(groupKey) + if !ok { + limiter = ratelimit.New(groupRateLimit, l.burstLimit) + } + l.groupLimiters.Put(groupKey, limiter) + l.groupMtx.Unlock() + allowed := limiter.Allow() + + return allowed +} + +// LogDrops consolidates the logic for logging dropped connections with +// throttling. +// +// This function is safe for concurrent access. +func (l *inboundRateLimiter) LogDrops(addr *addrmgr.NetAddress, reason string) { + l.logMtx.Lock() + defer l.logMtx.Unlock() + + // Only log a few dropped connections individually with a periodic summary + // once the rate is exceeded to prevent flooding spam. + if !l.logLimiter.Allow() { + // Report how long no further dropped connections will be logged when + // suppression starts. + if l.droppedLogs == 0 { + nextAllowed := l.logLimiter.UntilNextAllowed() + log.Debugf("Dropped connection from %v: %v -- suppressing drop "+ + "logs for %v", addr, reason, nextAllowed.Round(time.Second)) + + // Report a summary of the total number of suppressed dropped + // connections once messages are allowed again, but only when the + // logging level requires it. + if log.Level() <= slog.LevelDebug { + time.AfterFunc(nextAllowed, func() { + l.logMtx.Lock() + defer l.logMtx.Unlock() + + if dropped := l.droppedLogs; dropped > 0 { + log.Debugf("Dropped %d %s while suppressed", dropped, + pickNoun(dropped, "connection", "connections")) + } + l.droppedLogs = 0 + }) + } + } + + l.droppedLogs++ + return + } + log.Debugf("Dropped connection from %v: %v", addr, reason) +} + // listenHandler accepts incoming connections on a given listener. It must be // run as a goroutine. func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) { @@ -1208,6 +1425,15 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) continue } isWhitelisted := cm.IsWhitelisted(rAddr) + isLoopback := net.IP(rAddr.IP).IsLoopback() + + // Apply rate limiting for inbound connections that are not whitelisted + // or originating from a loopback address. + if !isWhitelisted && !isLoopback && !cm.inboundLimiter.Allow(rAddr) { + cm.inboundLimiter.LogDrops(rAddr, "rate limited") + netConn.Close() + continue + } // Reject connections with the same host:port as any existing pending, // established, or persistent connections. Note that this does NOT @@ -1220,7 +1446,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) cm.connMtx.Lock() if err := cm.rejectDuplicateAddr(rAddr); err != nil { cm.connMtx.Unlock() - log.Debugf("Dropped connection from %v: %v", rAddr, err) + cm.inboundLimiter.LogDrops(rAddr, err.Error()) netConn.Close() continue } @@ -1229,7 +1455,7 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) err = cm.rejectMaxConnsPerHost(rAddr, isWhitelisted) if err != nil { cm.connMtx.Unlock() - log.Debugf("Dropped connection from %v: %v", rAddr, err) + cm.inboundLimiter.LogDrops(rAddr, err.Error()) netConn.Close() continue } @@ -1249,10 +1475,10 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) continue } if !acquired { - maxAllowed := cm.cfg.MaxNormalConns - log.Debugf("Dropped connection from %v: a maximum of %d %s is "+ - "allowed", rAddr, maxAllowed, pickNoun(maxAllowed, - "connection", "connections")) + maxConns := cm.cfg.MaxNormalConns + reason := fmt.Sprintf("a maximum of %d %s is allowed", maxConns, + pickNoun(maxConns, "connection", "connections")) + cm.inboundLimiter.LogDrops(rAddr, reason) netConn.Close() continue } @@ -1786,6 +2012,7 @@ func New(cfg *Config) (*ConnManager, error) { totalNormalConnsSem: makeSemaphore(cfg.MaxNormalConns), activeOutboundsSem: makeSemaphore(cfg.TargetOutbound), outboundGroups: newOutboundGroupInfo(csprng), + inboundLimiter: newInboundRateLimiter(csprng), persistent: make(map[uint64]*persistentEntry, MaxPersistent), pending: make(map[uint64]*pendingConnInfo), active: make(map[uint64]*Conn, cfg.TargetOutbound), diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index 2305521ba..f44c130ec 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -295,6 +295,8 @@ func newTestConnManager(t *testing.T, cfg *Config) *ConnManager { cm.csprng = mrand.New(src) // nolint:gosec cm.outboundGroups.key[0] = cm.csprng.Uint64() cm.outboundGroups.key[1] = cm.csprng.Uint64() + cm.inboundLimiter.key[0] = cm.csprng.Uint64() + cm.inboundLimiter.key[1] = cm.csprng.Uint64() return cm } @@ -1709,6 +1711,7 @@ func TestRejectDuplicateConns(t *testing.T) { disconnected <- conn }, }) + cm.inboundLimiter.burstLimit = 4 ctx, _, _ := runConnMgrAsync(t, cm) // Dial a manual connection and wait for it to become pending. From e170d00fd50bf060e1f095173789e710789c00e5 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:17:59 -0500 Subject: [PATCH 48/51] connmgr: Add rate limiting tests. This adds tests to ensure the new rate limiting properly limits inbound connections at the intended rate and allows small controlled periodic bursts. It also includes testing of more than one simultaneous inbound group. --- internal/connmgr/connmanager_test.go | 86 +++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index f44c130ec..aa819bead 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -12,6 +12,7 @@ import ( "errors" "flag" "fmt" + "math" mrand "math/rand/v2" "net" "net/netip" @@ -133,6 +134,7 @@ func mustParseAddrPort(addr string) *addrmgr.NetAddress { // addrGenerator houses state for an address generator used to simplify tests. type addrGenerator struct { mtx sync.Mutex + inboundGroupPrefixBits uint outboundGroupPrefixBits uint addr netip.Addr port uint16 @@ -143,6 +145,7 @@ type addrGenerator struct { func newAddrGenerator(baseAddrPort string) *addrGenerator { addrPort := netip.MustParseAddrPort(baseAddrPort) return &addrGenerator{ + inboundGroupPrefixBits: 32, outboundGroupPrefixBits: 16, addr: addrPort.Addr(), port: addrPort.Port(), @@ -220,6 +223,17 @@ func (g *addrGenerator) nextPrefix(prefixBits uint) *addrmgr.NetAddress { return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) } +// NextInboundGroup advances the generator to the next inbound group IP and +// returns the result. It skips all addresses of the form "x.x.x.0". +// +// An inbound group is determined by a certain number of prefix bits. +func (g *addrGenerator) NextInboundGroup() *addrmgr.NetAddress { + g.mtx.Lock() + defer g.mtx.Unlock() + + return g.nextPrefix(g.inboundGroupPrefixBits) +} + // NextOutboundGroup advances the generator to the next outbound group IP and // returns the result. It skips all addresses of the form "x.x.x.0". // @@ -235,7 +249,7 @@ func (g *addrGenerator) NextOutboundGroup() *addrmgr.NetAddress { // starting base address and port useful throughout the tests. The base address // is a normal routable IPv4 address. func defaultAddrGenerator() *addrGenerator { - return newAddrGenerator(fmt.Sprintf("12.0.0.0:%d", defaultTestP2PPort)) + return newAddrGenerator(fmt.Sprintf("12.1.1.0:%d", defaultTestP2PPort)) } // defaultTestAddr returns a default address to use throughout the tests. It is @@ -2168,3 +2182,73 @@ func TestOutboundGroups(t *testing.T) { len(groups), targetOutbound) } } + +// TestInboundRateLimiting ensures the connection manager rate limits inbound +// connections as expected. It includes tests for normal rate limiting and +// bursts. +func TestInboundRateLimiting(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + inboundConns := make(chan *Conn) + listener := defaultMockListener() + cm := newTestConnManager(t, &Config{ + Listeners: []net.Listener{listener}, + MaxConnsPerHost: 100, + OnAccept: func(conn *Conn) { + inboundConns <- conn + }, + Dial: mockDialer, + }) + runConnMgrAsync(t, cm) + + // Ensure exactly the max allowed burst of inbound connections from the + // same address are accepted. + addrGen := defaultAddrGenerator() + addrGen.Next() + for range groupBurstLimit { + go listener.Connect(addrGen.NextPort()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound).Close() + } + assertConnManagerInternalState(t, cm) + + // Ensure connections from the same address are now rate limited. + for range 3 { + go listener.Connect(addrGen.NextPort()) + } + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cm) + + // Wait just long enough for the next connection to be allowed and + // ensure it is. + perConnSecs := time.Duration(math.Ceil(1/groupRateLimit)) * time.Second + time.Sleep(perConnSecs - connTestNonReceiveTimeout) + go listener.Connect(addrGen.NextPort()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound).Close() + assertConnManagerInternalState(t, cm) + + // Wait just long enough to reset the burst tokens and ensure another + // burst of inbound connections from the same address are accepted. + time.Sleep(groupBurstLimit * perConnSecs) + for range groupBurstLimit { + go listener.Connect(addrGen.NextPort()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound).Close() + } + assertConnManagerInternalState(t, cm) + + // Ensure the next inbound group is not rate limited and independently + // allows the max allowed burst. + addrGen.NextInboundGroup() + for range groupBurstLimit { + go listener.Connect(addrGen.NextPort()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound).Close() + } + assertConnManagerInternalState(t, cm) + + // Ensure connections from the same address in the new inbound group are + // now rate limited. + for range 3 { + go listener.Connect(addrGen.NextPort()) + } + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cm) + }) +} From f895318947268cd376b9e0016275b9e7a53aaf50 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:18:00 -0500 Subject: [PATCH 49/51] connmgr: Add inbound anti-flooding measures. Per-group rate limiting is a great first line of defense, however, it is not able to effectively handle prefix spraying by more sophisticated attackers who control large numbers of prefixes because the number of individual rate limiters necessarily must be limited as well to prevent unchecked memory growth. This makes such attacks significantly more difficult by implementing anti-flood measures which detect them and dynamically adjust the limiting to actively combat it. The detection uses an extremely efficient sliding window approach that tracks allowed connections over the past minute via a ring buffer. Once flooding is detected, the prefixes for both IPv4 and IPv6 are coarsened to treat larger blocks of addresses that attackers are most likely to control as a single entity thereby reducing the pressure on the cache. IPv4 is modified to /24 and IPv6 to /56. Second, probabilistic connection dropping that uses a quadratic rational S-curve to dynamically increase the drop probability according to the intensity of the flooding is activated. This approach was chosen because it ultimately still allows honest peers to periodically make it through even in the face of a sustained DDoS attack whereas more naive approaches do not. --- internal/connmgr/connmanager.go | 220 +++++++++++++++++++++++++++++--- internal/connmgr/csprng.go | 9 ++ 2 files changed, 209 insertions(+), 20 deletions(-) diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index bc1f7f25e..e7d1cbce8 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -101,6 +101,27 @@ const ( // per minute to be logged with periodic bursts up to 4. dropLogRateLimit = float64(1) / 60 dropLogBurstLimit = 4 + + // floodLow is the number of allowed connection attempts in the last minute + // to consider active low-intensity connection flooding. ~5 average + // attempts per second. + // + // floodHighFactor is the multiple of [floodLow] to consider active + // high-intensity connection flooding. + floodLow = 5 * 60 + floodHighFactor = 3 + + // These values tune how often connections are probabilistically dropped + // during active flooding. A quadratic rational S-curve is used. See + // [inboundRateLimiter.ShouldDropProbabilistic]. + // + // floodMinDropProb and floodMaxDropProb are the minimum and maximum + // probability with which to drop connections during active flooding. + // + // floodRamp is the normalized intensity to start rapid growth. + floodMinDropProb = 0.2 + floodMaxDropProb = 0.85 + floodRamp = 0.1 ) // ConnectionType specifies the different types of supported connections. @@ -553,7 +574,7 @@ type ConnManager struct { outboundGroups *outboundGroupInfo // inboundLimiter tracks information about inbound connections and provides - // per group rate limiting. + // per group rate limiting with anti-flood protection. inboundLimiter *inboundRateLimiter // ****************************************************************** @@ -1221,7 +1242,8 @@ type inboundGroupKey struct { hash1 uint64 } -// inboundRateLimiter houses state related to rate limiting inbound connections. +// inboundRateLimiter houses state related to rate limiting inbound connections +// and flood detection. type inboundRateLimiter struct { // burstLimit is the max burst size for the group rate limiters. It is set // to [groupBurstLimit] by default. @@ -1239,6 +1261,27 @@ type inboundRateLimiter struct { groupMtx sync.Mutex groupLimiters *lru.Map[inboundGroupKey, *ratelimit.Limiter] + // These fields are protected by the flood mutex. + // + // attempts houses a sliding window of the number of allowed connections per + // second over the previous minute as a ring buffer. + // + // attemptsStart is the current unix time for the head of the ring buffer. + // + // totalAttempts is the sum of all attempts in the window. + floodMtx sync.RWMutex + attempts [60]uint32 + attemptsStart int64 + totalAttempts uint64 + + // flooding tracks whether or not flooding mode is active. It is an atomic + // so that it is independently safe to read concurrently without any + // additional mutex. + // + // Nevertheless, it is only modified under [floodMtx] since it depends on + // [totalAttempts]. + flooding atomic.Bool + // These fields are protected by the log mutex. // // logLimiter provides rate limiting for logging of dropped inbound @@ -1270,28 +1313,39 @@ func newInboundRateLimiter(csprng csprng) *inboundRateLimiter { // This should not be confused with the outbound group key. They are not the // same and serve different purposes. // -// The group for IPv4 is the entire address (/32 prefix) and the typical -// residential block for IPv6 (/64 prefix). +// By default, the group for IPv4 is the entire address (/32 prefix) and the +// typical residential block for IPv6 (/64 prefix). When flooding is detected, +// the groups are coarsened to /24 for IPv4 and /56 for IPv6 in order to +// increase the cost of prefix spraying. // -// For IPv4, that has the effect of rate limiting individual addresses. +// For IPv4, that has the effect of rate limiting individual addresses during +// normal conditions and dynamically adjusting to rate limit the blocks of IPv4s +// that are the easiest for the same attacker to control as a single entity. // -// For IPv6, it has the effect of rate limiting all addresses in the typical -// residential blocks assigned by ISPs as a single entity since they are the -// easiest for the same attacker to control. +// Similarly, for IPv6, it has the effect of rate limiting all addresses in the +// typical residential blocks assigned by ISPs as a single entity and +// dynamically adjusting to encompass the full range typically assigned by an +// ISP since they are the easiest for the same attacker to control. // // This function is safe for concurrent access. func (l *inboundRateLimiter) GroupKey(addr *addrmgr.NetAddress) inboundGroupKey { var preimage []byte switch addr.Type { case addrmgr.IPv4Address: - const bits = 32 + bits := 32 + if l.flooding.Load() { + bits = 24 + } ip, _ := netip.AddrFromSlice(addr.IP) prefix, _ := ip.Prefix(bits) prefixBytes := prefix.Addr().As4() preimage = prefixBytes[:] case addrmgr.IPv6Address: - const bits = 64 + bits := 64 + if l.flooding.Load() { + bits = 56 + } ip, _ := netip.AddrFromSlice(addr.IP) prefix, _ := ip.Prefix(bits) prefixBytes := prefix.Addr().As16() @@ -1314,15 +1368,82 @@ func (l *inboundRateLimiter) GroupKey(addr *addrmgr.NetAddress) inboundGroupKey return inboundGroupKey{h0, h1} } +// recordAttempt updates the flood state to decay stale data and records +// attempts that were not rate limited by prefix. +// +// The flood detection scheme is a simple sliding window over the prior minute +// implemented as an efficient ring buffer. +func (l *inboundRateLimiter) recordAttempt(rateLimited bool) { + // buckets is the number of one second buckets in the sliding window. + const buckets = 60 + + now := time.Now() + nowUnix := now.Unix() + idx := nowUnix % buckets + + l.floodMtx.Lock() + defer l.floodMtx.Unlock() + + // Advance the sliding window to the current time. This approach is + // extremely efficient and still provides excellent properties such as + // deterministic behavior, good burst detection, and reasonable reaction + // time. + // + // This entire section only runs a max of once per second. + // + // In other words, in a real flooding scenario where this might be called + // hundreds or thousands of times per second, it is effectively a noop. + // + // On the other end of the spectrum, when more than 1 min has elapsed since + // the last update, it reduces to a tiny memcpy to zero the array. + // + // Otherwise, it is somewhere between clearing 1 to 59 buckets which only + // involves very cheap calculations. + if l.attemptsStart != nowUnix { + numExpired := nowUnix - l.attemptsStart + if numExpired >= buckets { + for i := range l.attempts { + l.attempts[i] = 0 + } + l.totalAttempts = 0 + } else if numExpired > 0 { + tail := l.attemptsStart + 1 + for i := range numExpired { + oldIdx := (tail + i) % buckets + l.totalAttempts -= uint64(l.attempts[oldIdx]) + l.attempts[oldIdx] = 0 + } + } + l.attemptsStart = nowUnix + } + + // Record allowed attempts. + if !rateLimited { + l.attempts[idx]++ + l.totalAttempts++ + } + + // Activate flooding mode if there have been enough recent allowed attempts + // in the last minute. Deactivate otherwise. + l.flooding.Store(l.totalAttempts > floodLow) +} + // Allow updates the limiter state for the given address and returns whether an // inbound connection from it is permitted at the current time. // -// It enforces a per group (prefix based) rate limit. The connection is allowed -// when that rate limit has not been exceeded. +// This is the first line of defense against inbound attacks. +// +// It enforces a per group (prefix based) rate limit that is dynamically +// adjusted depending on whether flooding is detected. The connection is +// allowed when that rate limit has not been exceeded. +// +// It also records allowed attempts for flood detection and updates the flood +// state when needed. // // Care must be taken when modifying this method. It is in the critical hot // path for every inbound connection and must remain fast and tightly control -// memory usage in order to remain resilient under sustained misbehavior. +// memory usage in order to remain resilient under sustained misbehavior and +// flooding. // // This function is safe for concurrent access. func (l *inboundRateLimiter) Allow(addr *addrmgr.NetAddress) bool { @@ -1334,7 +1455,7 @@ func (l *inboundRateLimiter) Allow(addr *addrmgr.NetAddress) bool { // // Adding a new entry may evict another limiter when at max capacity. In // practice, that case is only realistically possible to hit when under - // a heavy DDoS attack. + // a heavy DDoS attack. The anti-flooding measures help combat that. groupKey := l.GroupKey(addr) l.groupMtx.Lock() limiter, ok := l.groupLimiters.Get(groupKey) @@ -1345,9 +1466,57 @@ func (l *inboundRateLimiter) Allow(addr *addrmgr.NetAddress) bool { l.groupMtx.Unlock() allowed := limiter.Allow() + // Tally attempts that were not rate limited and periodically update the + // state related to detecting active flooding. + l.recordAttempt(!allowed) + return allowed } +// ShouldDropProbabilistic returns whether or not a connection should be +// probabilistically dropped. +// +// No probabilistic dropping is applied unless active flooding is detected and +// the probability of dropping a connection is increased according to the +// intensity of flooding per an S-curve. +// +// This acts as a second layer of defense against any inbound attackers that +// make it through the initial rate limiting. +// +// This function is safe for concurrent access. +func (l *inboundRateLimiter) ShouldDropProbabilistic(csprng csprng) bool { + // Don't probabilistically drop anything unless flooding has been detected. + if !l.flooding.Load() { + return false + } + + l.floodMtx.RLock() + totalAttempts := l.totalAttempts + l.floodMtx.RUnlock() + + // Scale the probability of dropping a connection in accordance with the + // intensity of flooding using a smooth quadratic rational S-curve between + // the min and max drop probability. + // + // This provides a smooth non-linear curve to apply increasing backpressure + // as flooding intensifies. + // + // The equation is: + // + // P(x) = minP + (1+r)*(maxP-minP)*x^2 / (r+x^2) + // + // Where x is the normalized flood intensity in the range [0, 1] and r is + // normalized intensity to start rapid growth. + const factor = (1 + floodRamp) * (floodMaxDropProb - floodMinDropProb) + const floodHigh = floodLow * floodHighFactor + norm := float64(max(totalAttempts, floodLow)-floodLow) / float64(floodHigh) + norm = min(norm, 1.0) + nSquared := norm * norm + prob := floodMinDropProb + factor*nSquared/(floodRamp+nSquared) + + return csprng.Float64() < prob +} + // LogDrops consolidates the logic for logging dropped connections with // throttling. // @@ -1381,6 +1550,7 @@ func (l *inboundRateLimiter) LogDrops(addr *addrmgr.NetAddress, reason string) { l.droppedLogs = 0 }) } + return } l.droppedLogs++ @@ -1427,12 +1597,22 @@ func (cm *ConnManager) listenHandler(ctx context.Context, listener net.Listener) isWhitelisted := cm.IsWhitelisted(rAddr) isLoopback := net.IP(rAddr.IP).IsLoopback() - // Apply rate limiting for inbound connections that are not whitelisted - // or originating from a loopback address. - if !isWhitelisted && !isLoopback && !cm.inboundLimiter.Allow(rAddr) { - cm.inboundLimiter.LogDrops(rAddr, "rate limited") - netConn.Close() - continue + // Apply rate limiting and anti-flooding measures for inbound + // connections that are not whitelisted or originating from a loopback + // address. + if !isWhitelisted && !isLoopback { + if !cm.inboundLimiter.Allow(rAddr) { + cm.inboundLimiter.LogDrops(rAddr, "rate limited") + netConn.Close() + continue + } + + if cm.inboundLimiter.ShouldDropProbabilistic(cm.csprng) { + cm.inboundLimiter.LogDrops(rAddr, "probabilistically blocked "+ + "during flood") + netConn.Close() + continue + } } // Reject connections with the same host:port as any existing pending, diff --git a/internal/connmgr/csprng.go b/internal/connmgr/csprng.go index 2aae55268..75bcc6a58 100644 --- a/internal/connmgr/csprng.go +++ b/internal/connmgr/csprng.go @@ -16,6 +16,7 @@ import ( type csprng interface { Uint64() uint64 Uint64N(n uint64) uint64 + Float64() float64 } // lockingPRNG wraps an instance of [rand.PRNG] with a mutex so it can be used @@ -41,6 +42,14 @@ func (p *lockingPRNG) Uint64N(n uint64) uint64 { return p.prng.Uint64N(n) } +// Float64 returns a random float64 in the half-open interval [0.0,1.0). +func (p *lockingPRNG) Float64() float64 { + p.Lock() + defer p.Unlock() + + return float64(p.prng.Uint64N(1<<53)) / (1 << 53) +} + // Read fills s with len(s) of cryptographically-secure random bytes. It never // errors. func (p *lockingPRNG) Read(s []byte) { From 7b47062cf1b4ef7f8da16d5335de267a8a1ae67f Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:18:00 -0500 Subject: [PATCH 50/51] connmgr: Add anti-flood tests. This updates the inbound rate limiting test to ensure the new anti-flood measures properly detect floods, limit inbound connections with the coarsened prefixes and probabilistically drop any that make through when flooding is active, and properly returns to normal behavior once flooding stops. --- internal/connmgr/connmanager_test.go | 134 ++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 2 deletions(-) diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index aa819bead..53ae73000 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -223,6 +223,19 @@ func (g *addrGenerator) nextPrefix(prefixBits uint) *addrmgr.NetAddress { return addrmgr.NewNetAddressFromIPPort(g.addr.AsSlice(), g.port, 0) } +// SetFlooded sets the address generator to use inbound prefix groups per the +// given flooded state. +func (g *addrGenerator) SetFlooded(active bool) { + g.mtx.Lock() + defer g.mtx.Unlock() + + bits := uint(32) + if active { + bits = 24 + } + g.inboundGroupPrefixBits = bits +} + // NextInboundGroup advances the generator to the next inbound group IP and // returns the result. It skips all addresses of the form "x.x.x.0". // @@ -410,6 +423,33 @@ func assertInternalOutboundGroupState(t *testing.T, cm *ConnManager) { } } +// assertInternalRateLimiterState ensures the internal rate limiting and +// flooding state of the passed connection manager instance is coherent. +func assertInternalRateLimiterState(t *testing.T, cm *ConnManager) { + t.Helper() + + l := cm.inboundLimiter + l.floodMtx.Lock() + defer l.floodMtx.Unlock() + + // Assert the total attempts counts matches the value obtained from manually + // tallying it. + var totalAttempts uint64 + for _, attempts := range l.attempts { + totalAttempts += uint64(attempts) + } + if l.totalAttempts != totalAttempts { + t.Fatalf("mismatched total attempts count: %d != %d", l.totalAttempts, + totalAttempts) + } + + // Assert the flooding flag status is correct. + flooding := l.totalAttempts > floodLow + if got := l.flooding.Load(); got != flooding { + t.Fatalf("mismatched flooding flag: %v != %v", got, flooding) + } +} + // assertConnManagerInternalState ensures the internal state of the passed // connection manager instance is coherent. func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { @@ -417,6 +457,7 @@ func assertConnManagerInternalState(t *testing.T, cm *ConnManager) { assertInternalConnState(t, cm) assertInternalOutboundGroupState(t, cm) + assertInternalRateLimiterState(t, cm) } // assertConnManagerCleanShutdown ensures the internal state of the passed @@ -666,6 +707,16 @@ func assertNoConnReceived(t *testing.T, ch <-chan *Conn) { assertNoConnReceivedTimeout(t, ch, connTestNonReceiveTimeout) } +// assertFlooded ensures the flooding status of the connection manager is the +// given value. +func assertFlooded(t *testing.T, cm *ConnManager, want bool) { + t.Helper() + + if got := cm.inboundLimiter.flooding.Load(); got != want { + t.Fatalf("flooding status %v is not %v", got, want) + } +} + // TestConnectMode tests that the connection manager works in the connect mode. // // In connect mode, automatic connections are disabled, so test that connections @@ -2184,8 +2235,8 @@ func TestOutboundGroups(t *testing.T) { } // TestInboundRateLimiting ensures the connection manager rate limits inbound -// connections as expected. It includes tests for normal rate limiting and -// bursts. +// connections behavior as expected. It includes tests for normal rate +// limiting, bursts, flooding, and flood recovery. func TestInboundRateLimiting(t *testing.T) { synctest.Test(t, func(t *testing.T) { inboundConns := make(chan *Conn) @@ -2250,5 +2301,84 @@ func TestInboundRateLimiting(t *testing.T) { } assertNoConnReceived(t, inboundConns) assertConnManagerInternalState(t, cm) + + // Make exactly enough allowed connections to reach one prior to the low + // intensity flood cutover. Ensure all connections are accepted and + // flooding is not detected. + // + // Wait long enough to reset the flood state first to simplify the + // calcs. + // + // Then, advance time such that the connections fill up the entire + // sliding window used to track allowed connections for flood detection. + time.Sleep(60 * time.Second) + for i := range floodLow { + if i%groupBurstLimit == 0 { + addrGen.NextInboundGroup() + } + go listener.Connect(addrGen.NextPort()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound).Close() + time.Sleep(60 * time.Second / (floodLow + 1)) + } + assertConnManagerInternalState(t, cm) + assertFlooded(t, cm, false) + + // Ensure the next connection activates flooding mode. + // + // The current inbound group might be rate limited depending on the + // actual values above and flooding mode will have coarsened it, so use + // an address from the next inbound group with the coarsened prefix. + // + // The connection may be probabilistically dropped due to flooding, so + // it may or may not be accepted. + addrGen.SetFlooded(true) + go listener.Connect(addrGen.NextInboundGroup()) + select { + case conn := <-inboundConns: + assertConnType(t, conn, ConnTypeInbound) + conn.Close() + case <-time.After(connTestNonReceiveTimeout): + } + assertConnManagerInternalState(t, cm) + assertFlooded(t, cm, true) + + // Make enough connections from the same address to hit the group burst + // limit. + // + // The connections may be probabilistically dropped due to flooding, so + // they may or may not be accepted. + for range groupBurstLimit { + go listener.Connect(addrGen.NextPort()) + select { + case conn := <-inboundConns: + assertConnType(t, conn, ConnTypeInbound) + conn.Close() + case <-time.After(connTestNonReceiveTimeout): + } + } + assertConnManagerInternalState(t, cm) + assertFlooded(t, cm, true) + + // Ensure the next few addresss in the same inbound group are now rate + // limited. Use a value high enough to ensure they aren't just being + // dropped probabilistically. + for range int(math.Ceil(groupBurstLimit/(1-floodMaxDropProb))) * 2 { + go listener.Connect(addrGen.Next()) + } + assertNoConnReceived(t, inboundConns) + assertConnManagerInternalState(t, cm) + + // Ensure flood mode is deactivated once flooding subsides. + // + // Wait for an entire window to pass with no connections and then ensure + // the same address can connect up to the burst limit again. + time.Sleep(time.Minute) + addrGen.SetFlooded(false) + for range groupBurstLimit { + go listener.Connect(addrGen.NextPort()) + assertConnReceived(t, inboundConns, 0, ConnTypeInbound).Close() + } + assertConnManagerInternalState(t, cm) + assertFlooded(t, cm, false) }) } From 392ca861a205fc3474bdf523f6bd53e9ef658c7c Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Mon, 1 Jun 2026 01:18:01 -0500 Subject: [PATCH 51/51] connmgr: Update README.md for inbound limiting. --- internal/connmgr/README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/internal/connmgr/README.md b/internal/connmgr/README.md index 9321154ac..a03421f18 100644 --- a/internal/connmgr/README.md +++ b/internal/connmgr/README.md @@ -13,7 +13,8 @@ logic. It handles all general connection lifecycle concerns such as accepting inbound connections, automatically maintaining a set number of outbound connections, -maintaining persistent connections, enforcing limits, and preventing duplicates. +maintaining persistent connections, preventing duplicates, and enforcing +multiple layers of connection limits and anti-abuse protections. The design has a strong emphasis on reliability, readability, and efficiency under high connection load while also aiming to provide an ergonomic API. @@ -23,6 +24,12 @@ The following is a brief overview of the key features: - Inbound listening - Accepts inbound connections on provided `Listeners` - Uses connection shedding for rejected inbound connections + - Provides token bucket rate limiting on a per network group basis + - Anti-flood protection + - Detects floods based on allowed connection attempts + - Dynamically coarsens network group rate limiting during flooding + - Probabilistically drops connections when flooding is active via an S-curve + - Rate limits logging of dropped connections - Automatic outbound maintenance - Maintains up to `TargetOutbound` normal outbound connections via a provided address source (`GetNewAddress`)