diff --git a/cmd/bee/cmd/cmd.go b/cmd/bee/cmd/cmd.go index b9dbaa12f0e..4f75d40a2da 100644 --- a/cmd/bee/cmd/cmd.go +++ b/cmd/bee/cmd/cmd.go @@ -88,6 +88,7 @@ const ( optionAutoTLSDomain = "autotls-domain" optionAutoTLSRegistrationEndpoint = "autotls-registration-endpoint" optionAutoTLSCAEndpoint = "autotls-ca-endpoint" + optionUseSIMD = "use-simd-hashing" // blockchain-rpc optionNameBlockchainRpcEndpoint = "blockchain-rpc-endpoint" @@ -337,6 +338,7 @@ func (c *command) setAllFlags(cmd *cobra.Command) { cmd.Flags().String(optionAutoTLSDomain, p2pforge.DefaultForgeDomain, "autotls domain") cmd.Flags().String(optionAutoTLSRegistrationEndpoint, p2pforge.DefaultForgeEndpoint, "autotls registration endpoint") cmd.Flags().String(optionAutoTLSCAEndpoint, p2pforge.DefaultCAEndpoint, "autotls certificate authority endpoint") + cmd.Flags().Bool(optionUseSIMD, false, "use SIMD BMT hasher (available only on linux amd64 platforms)") } // preRun must be called from every command's PreRunE, after which c.logger is diff --git a/cmd/bee/cmd/start.go b/cmd/bee/cmd/start.go index 469eb207f8f..fb42347b2f6 100644 --- a/cmd/bee/cmd/start.go +++ b/cmd/bee/cmd/start.go @@ -15,6 +15,7 @@ import ( "os" "os/signal" "path/filepath" + "runtime" "strings" "sync/atomic" "syscall" @@ -22,8 +23,11 @@ import ( "github.com/ethersphere/bee/v2" "github.com/ethersphere/bee/v2/pkg/accesscontrol" + "github.com/ethersphere/bee/v2/pkg/bmt" + "github.com/ethersphere/bee/v2/pkg/bmtpool" chaincfg "github.com/ethersphere/bee/v2/pkg/config" "github.com/ethersphere/bee/v2/pkg/crypto" + "github.com/ethersphere/bee/v2/pkg/keccak" "github.com/ethersphere/bee/v2/pkg/keystore" filekeystore "github.com/ethersphere/bee/v2/pkg/keystore/file" memkeystore "github.com/ethersphere/bee/v2/pkg/keystore/mem" @@ -266,6 +270,21 @@ func buildBeeNode(ctx context.Context, c *command, cmd *cobra.Command, logger lo neighborhoodSuggester = c.config.GetString(optionNameNeighborhoodSuggester) } + useSIMD := 
c.config.GetBool(optionUseSIMD) + if useSIMD { + if runtime.GOOS != "linux" || runtime.GOARCH != "amd64" { + return nil, fmt.Errorf("SIMD hashing requires linux/amd64 (this build is %s/%s)", runtime.GOOS, runtime.GOARCH) + } + if !keccak.HasSIMD() { + return nil, errors.New("SIMD hashing requires a CPU with AVX2 or AVX-512; this CPU has neither") + } + bmt.SetSIMDOptIn(true) + // Rebuild the global bmtpool instance so the new SIMDOptIn value + // is reflected in the pool created for hot-path BMT hashing. + bmtpool.Rebuild() + logger.Info("SIMD hashing enabled", "batch_width", keccak.BatchWidth(), "avx512", keccak.HasAVX512()) + } + b, err := node.NewBee(ctx, c.config.GetString(optionNameP2PAddr), signerConfig.publicKey, signerConfig.signer, networkID, logger, signerConfig.libp2pPrivateKey, signerConfig.pssPrivateKey, signerConfig.session, &node.Options{ Addr: c.config.GetString(optionNameP2PAddr), AllowPrivateCIDRs: c.config.GetBool(optionNameAllowPrivateCIDRs), diff --git a/go.mod b/go.mod index 6885b1805d0..c8b0cc496f5 100644 --- a/go.mod +++ b/go.mod @@ -103,7 +103,7 @@ require ( github.com/ipfs/go-log/v2 v2.6.0 // indirect github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 github.com/koron/go-ssdp v0.0.6 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/libdns/libdns v0.2.2 // indirect diff --git a/go.sum b/go.sum index d0fc255af2a..a8510646f91 100644 --- a/go.sum +++ b/go.sum @@ -541,8 +541,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW github.com/klauspost/cpuid v0.0.0-20170728055534-ae7887de9fa5/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.6/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 
-github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= -github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg= github.com/klauspost/pgzip v1.0.2-0.20170402124221-0bf5dcad4ada/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/klauspost/reedsolomon v1.11.8 h1:s8RpUW5TK4hjr+djiOpbZJB4ksx+TdYbRH7vHQpwPOY= diff --git a/pkg/bmt/benchmark_test.go b/pkg/bmt/benchmark_test.go index aa6cc1703dd..561909b1985 100644 --- a/pkg/bmt/benchmark_test.go +++ b/pkg/bmt/benchmark_test.go @@ -26,7 +26,16 @@ func BenchmarkBMT(b *testing.B) { b.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(b *testing.B) { benchmarkRefHasher(b, size) }) - b.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(b *testing.B) { + b.Run(fmt.Sprintf("%v_size_%v", "BMT_Goroutine", size), func(b *testing.B) { + prev := bmt.SIMDOptIn() + bmt.SetSIMDOptIn(false) + defer bmt.SetSIMDOptIn(prev) + benchmarkBMT(b, size) + }) + b.Run(fmt.Sprintf("%v_size_%v", "BMT_SIMD", size), func(b *testing.B) { + prev := bmt.SIMDOptIn() + bmt.SetSIMDOptIn(true) + defer bmt.SetSIMDOptIn(prev) benchmarkBMT(b, size) }) } @@ -81,13 +90,15 @@ func benchmarkBMTBaseline(b *testing.B, _ int) { } } -// benchmarks BMT Hasher +// benchmarks BMT Hasher with whichever pool the current SIMDOptIn selects. +// Tree is constructed inside the benchmark so the SIMDOptIn value at call time +// determines which implementation gets exercised. 
func benchmarkBMT(b *testing.B, n int) { b.Helper() testData := testutil.RandBytesWithSeed(b, 4096, seed) - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize)) + pool := bmt.NewPool(bmt.NewConf(testSegmentCount, testPoolSize)) h := pool.Get() defer pool.Put(h) @@ -106,7 +117,7 @@ func benchmarkPool(b *testing.B, poolsize int) { testData := testutil.RandBytesWithSeed(b, 4096, seed) - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, poolsize)) + pool := bmt.NewPool(bmt.NewConf(testSegmentCount, poolsize)) cycles := 100 b.ReportAllocs() diff --git a/pkg/bmt/bmt.go b/pkg/bmt/bmt.go index ae3c5e95421..fb77650d17f 100644 --- a/pkg/bmt/bmt.go +++ b/pkg/bmt/bmt.go @@ -11,52 +11,17 @@ import ( "github.com/ethersphere/bee/v2/pkg/swarm" ) -var _ Hash = (*Hasher)(nil) - var ( zerospan = make([]byte, 8) zerosection = make([]byte, 64) ) -// Hasher is a reusable hasher for fixed maximum size chunks representing a BMT -// It reuses a pool of trees for amortised memory allocation and resource control, -// and supports order-agnostic concurrent segment writes and section (double segment) writes -// as well as sequential read and write. -// -// The same hasher instance must not be called concurrently on more than one chunk. -// -// The same hasher instance is synchronously reusable. -// -// Sum gives back the tree to the pool and guaranteed to leave -// the tree and itself in a state reusable for hashing a new chunk. 
-type Hasher struct { - *Conf // configuration - bmt *tree // prebuilt BMT resource for flowcontrol and proofs - size int // bytes written to Hasher since last Reset() - pos int // index of rightmost currently open segment - result chan []byte // result channel - errc chan error // error channel - span []byte // The span of the data subsumed under the chunk -} - -// NewHasher gives back an instance of a Hasher struct -func NewHasher(hasherFact func() hash.Hash) *Hasher { - conf := NewConf(hasherFact, swarm.BmtBranches, 32) +// SIMDOptIn reports whether the SIMD hasher has been opted into. +var SIMDOptIn func() bool = func() bool { return false } - return &Hasher{ - Conf: conf, - result: make(chan []byte), - errc: make(chan error, 1), - span: make([]byte, SpanSize), - bmt: newTree(conf.maxSize, conf.depth, conf.hasher), - } -} - -// Capacity returns the maximum amount of bytes that will be processed by this hasher implementation. -// since BMT assumes a balanced binary tree, capacity it is always a power of 2 -func (h *Hasher) Capacity() int { - return h.maxSize -} +// SetSIMDOptIn sets the SIMD opt-in flag. Intended to be called once during +// startup before the first NewPool call (cmd/bee calls it after flag parsing). +func SetSIMDOptIn(b bool) { SIMDOptIn = func() bool { return b } } // LengthToSpan creates a binary data span size representation. // It is required for calculating the BMT hash. @@ -71,211 +36,6 @@ func LengthFromSpan(span []byte) uint64 { return binary.LittleEndian.Uint64(span) } -// SetHeaderInt64 sets the metadata preamble to the little endian binary representation of int64 argument for the current hash operation. -func (h *Hasher) SetHeaderInt64(length int64) { - binary.LittleEndian.PutUint64(h.span, uint64(length)) -} - -// SetHeader sets the metadata preamble to the span bytes given argument for the current hash operation. 
-func (h *Hasher) SetHeader(span []byte) { - copy(h.span, span) -} - -// Size returns the digest size of the hash -func (h *Hasher) Size() int { - return h.segmentSize -} - -// BlockSize returns the optimal write size to the Hasher -func (h *Hasher) BlockSize() int { - return 2 * h.segmentSize -} - -// Hash returns the BMT root hash of the buffer and an error -// using Hash presupposes sequential synchronous writes (io.Writer interface). -func (h *Hasher) Hash(b []byte) ([]byte, error) { - if h.size == 0 { - return doHash(h.hasher(), h.span, h.zerohashes[h.depth]) - } - copy(h.bmt.buffer[h.size:], zerosection) - // write the last section with final flag set to true - go h.processSection(h.pos, true) - select { - case result := <-h.result: - return doHash(h.hasher(), h.span, result) - case err := <-h.errc: - return nil, err - } -} - -// Sum returns the BMT root hash of the buffer, unsafe version of Hash -func (h *Hasher) Sum(b []byte) []byte { - s, _ := h.Hash(b) - return s -} - -// Write calls sequentially add to the buffer to be hashed, -// with every full segment calls processSection in a go routine. -func (h *Hasher) Write(b []byte) (int, error) { - l := len(b) - maxVal := h.maxSize - h.size - if l > maxVal { - l = maxVal - } - copy(h.bmt.buffer[h.size:], b) - secsize := 2 * h.segmentSize - from := h.size / secsize - h.size += l - to := h.size / secsize - if l == maxVal { - to-- - } - h.pos = to - for i := from; i < to; i++ { - go h.processSection(i, false) - } - return l, nil -} - -// Reset prepares the Hasher for reuse -func (h *Hasher) Reset() { - h.pos = 0 - h.size = 0 - copy(h.span, zerospan) -} - -// processSection writes the hash of i-th section into level 1 node of the BMT tree. 
-func (h *Hasher) processSection(i int, final bool) { - secsize := 2 * h.segmentSize - offset := i * secsize - level := 1 - // select the leaf node for the section - n := h.bmt.leaves[i] - isLeft := n.isLeft - hasher := n.hasher - n = n.parent - // hash the section - section, err := doHash(hasher, h.bmt.buffer[offset:offset+secsize]) - if err != nil { - select { - case h.errc <- err: - default: - } - return - } - // write hash into parent node - if final { - // for the last segment use writeFinalNode - h.writeFinalNode(level, n, isLeft, section) - } else { - h.writeNode(n, isLeft, section) - } -} - -// writeNode pushes the data to the node. -// if it is the first of 2 sisters written, the routine terminates. -// if it is the second, it calculates the hash and writes it -// to the parent node recursively. -// since hashing the parent is synchronous the same hasher can be used. -func (h *Hasher) writeNode(n *node, isLeft bool, s []byte) { - var err error - for { - // at the root of the bmt just write the result to the result channel - if n == nil { - h.result <- s - return - } - // otherwise assign child hash to left or right segment - if isLeft { - n.left = s - } else { - n.right = s - } - // the child-thread first arriving will terminate - if n.toggle() { - return - } - // the thread coming second now can be sure both left and right children are written - // so it calculates the hash of left|right and pushes it to the parent - s, err = doHash(n.hasher, n.left, n.right) - if err != nil { - select { - case h.errc <- err: - default: - } - return - } - isLeft = n.isLeft - n = n.parent - } -} - -// writeFinalNode is following the path starting from the final datasegment to the -// BMT root via parents. -// For unbalanced trees it fills in the missing right sister nodes using -// the pool's lookup table for BMT subtree root hashes for all-zero sections. -// Otherwise behaves like `writeNode`. 
-func (h *Hasher) writeFinalNode(level int, n *node, isLeft bool, s []byte) { - var err error - for { - // at the root of the bmt just write the result to the result channel - if n == nil { - if s != nil { - h.result <- s - } - return - } - var noHash bool - if isLeft { - // coming from left sister branch - // when the final section's path is going via left child node - // we include an all-zero subtree hash for the right level and toggle the node. - n.right = h.zerohashes[level] - if s != nil { - n.left = s - // if a left final node carries a hash, it must be the first (and only thread) - // so the toggle is already in passive state no need no call - // yet thread needs to carry on pushing hash to parent - noHash = false - } else { - // if again first thread then propagate nil and calculate no hash - noHash = n.toggle() - } - } else { - // right sister branch - if s != nil { - // if hash was pushed from right child node, write right segment change state - n.right = s - // if toggle is true, we arrived first so no hashing just push nil to parent - noHash = n.toggle() - } else { - // if s is nil, then thread arrived first at previous node and here there will be two, - // so no need to do anything and keep s = nil for parent - noHash = true - } - } - // the child-thread first arriving will just continue resetting s to nil - // the second thread now can be sure both left and right children are written - // it calculates the hash of left|right and pushes it to the parent - if noHash { - s = nil - } else { - s, err = doHash(n.hasher, n.left, n.right) - if err != nil { - select { - case h.errc <- err: - default: - } - return - } - } - // iterate to parent - isLeft = n.isLeft - n = n.parent - level++ - } -} - // calculates the Keccak256 SHA3 hash of the data func sha3hash(data ...[]byte) ([]byte, error) { return doHash(swarm.NewHasher(), data...) 
@@ -291,3 +51,15 @@ func doHash(h hash.Hash, data ...[]byte) ([]byte, error) { } return h.Sum(nil), nil } + +// SEGMENT_SIZE is the keccak256 output size in bytes, also the BMT leaf segment size. +const SEGMENT_SIZE = 32 + +// sizeToParams calculates the depth (number of levels) and segment count in the BMT tree. +func sizeToParams(n int) (c, d int) { + c = 2 + for ; c < n; c *= 2 { + d++ + } + return c, d + 1 +} diff --git a/pkg/bmt/bmt_simd.go b/pkg/bmt/bmt_simd.go new file mode 100644 index 00000000000..e26b6fd1cd1 --- /dev/null +++ b/pkg/bmt/bmt_simd.go @@ -0,0 +1,160 @@ +// Copyright 2024 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package bmt + +import ( + "github.com/ethersphere/bee/v2/pkg/keccak" +) + +// hashSIMD computes the BMT root hash using SIMD-accelerated Keccak hashing. +// It processes the tree level by level from leaves to root, using batched +// SIMD calls instead of goroutine-per-section. A single thread handles all +// levels since SIMD already provides intra-call parallelism (4-way or 8-way). +func (h *simdHasher) hashSIMD() ([]byte, error) { + secsize := 2 * h.segmentSize + bw := h.batchWidth + prefixLen := len(h.prefix) + + // Leaf level: hash each section and write results to parent nodes. + h.hashLeavesBatch(0, len(h.bmt.levels[0]), bw, secsize, prefixLen) + + // Internal levels: process each level single-threaded (diminishing work). + for lvl := 1; lvl < len(h.bmt.levels)-1; lvl++ { + h.hashNodesBatch(h.bmt.levels[lvl], bw, prefixLen) + } + + // Root level: hash using scalar hasher. + root := h.bmt.levels[len(h.bmt.levels)-1][0] + return doHash(root.hasher, root.left, root.right) +} + +// hashLeavesBatch hashes leaf sections in the range [start, end) using SIMD batches. 
+func (h *simdHasher) hashLeavesBatch(start, end, bw, secsize, prefixLen int) { + buf := h.bmt.buffer + + if bw == 8 { + var inputs [8][]byte + for i := start; i < end; i += 8 { + batch := 8 + if i+batch > end { + batch = end - i + } + for j := 0; j < batch; j++ { + offset := (i + j) * secsize + if prefixLen > 0 { + copy(h.bmt.leafConcat[j][prefixLen:], buf[offset:offset+secsize]) + inputs[j] = h.bmt.leafConcat[j][:prefixLen+secsize] + } else { + inputs[j] = buf[offset : offset+secsize] + } + } + for j := batch; j < 8; j++ { + inputs[j] = nil + } + outputs := keccak.Sum256x8(inputs) + for j := 0; j < batch; j++ { + leaf := h.bmt.levels[0][i+j] + if leaf.isLeft { + copy(leaf.parent.left, outputs[j][:]) + } else { + copy(leaf.parent.right, outputs[j][:]) + } + } + } + } else { + var inputs [4][]byte + for i := start; i < end; i += 4 { + batch := 4 + if i+batch > end { + batch = end - i + } + for j := 0; j < batch; j++ { + offset := (i + j) * secsize + if prefixLen > 0 { + copy(h.bmt.leafConcat[j][prefixLen:], buf[offset:offset+secsize]) + inputs[j] = h.bmt.leafConcat[j][:prefixLen+secsize] + } else { + inputs[j] = buf[offset : offset+secsize] + } + } + for j := batch; j < 4; j++ { + inputs[j] = nil + } + outputs := keccak.Sum256x4(inputs) + for j := 0; j < batch; j++ { + leaf := h.bmt.levels[0][i+j] + if leaf.isLeft { + copy(leaf.parent.left, outputs[j][:]) + } else { + copy(leaf.parent.right, outputs[j][:]) + } + } + } + } +} + +// hashNodesBatch hashes a level of internal nodes using SIMD batches. +// Each node's left||right (64 bytes) is hashed to produce the input for its parent. 
+func (h *simdHasher) hashNodesBatch(nodes []*simdNode, bw, prefixLen int) { + count := len(nodes) + segSize := h.segmentSize + concat := &h.bmt.concat + + if bw == 8 { + var inputs [8][]byte + for i := 0; i < count; i += 8 { + batch := 8 + if i+batch > count { + batch = count - i + } + for j := 0; j < batch; j++ { + n := nodes[i+j] + copy(concat[j][prefixLen:prefixLen+segSize], n.left) + copy(concat[j][prefixLen+segSize:], n.right) + inputs[j] = concat[j][:prefixLen+2*segSize] + } + for j := batch; j < 8; j++ { + inputs[j] = nil + } + outputs := keccak.Sum256x8(inputs) + for j := 0; j < batch; j++ { + n := nodes[i+j] + if n.isLeft { + copy(n.parent.left, outputs[j][:]) + } else { + copy(n.parent.right, outputs[j][:]) + } + } + } + } else { + var inputs [4][]byte + for i := 0; i < count; i += 4 { + batch := 4 + if i+batch > count { + batch = count - i + } + for j := 0; j < batch; j++ { + n := nodes[i+j] + copy(concat[j][prefixLen:prefixLen+segSize], n.left) + copy(concat[j][prefixLen+segSize:], n.right) + inputs[j] = concat[j][:prefixLen+2*segSize] + } + for j := batch; j < 4; j++ { + inputs[j] = nil + } + outputs := keccak.Sum256x4(inputs) + for j := 0; j < batch; j++ { + n := nodes[i+j] + if n.isLeft { + copy(n.parent.left, outputs[j][:]) + } else { + copy(n.parent.right, outputs[j][:]) + } + } + } + } +} diff --git a/pkg/bmt/bmt_test.go b/pkg/bmt/bmt_test.go index c82fb896a6f..89e3a37e94d 100644 --- a/pkg/bmt/bmt_test.go +++ b/pkg/bmt/bmt_test.go @@ -45,7 +45,7 @@ func refHash(count int, data []byte) ([]byte, error) { } // syncHash hashes the data and the span using the bmt hasher -func syncHash(h *bmt.Hasher, data []byte) ([]byte, error) { +func syncHash(h bmt.Hasher, data []byte) ([]byte, error) { h.Reset() h.SetHeaderInt64(int64(len(data))) _, err := h.Write(data) @@ -67,7 +67,7 @@ func TestHasherEmptyData(t *testing.T) { if err != nil { t.Fatal(err) } - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, 1)) + pool := bmt.NewPool(bmt.NewConf(count, 1)) h 
:= pool.Get() resHash, err := syncHash(h, nil) if err != nil { @@ -92,7 +92,7 @@ func TestSyncHasherCorrectness(t *testing.T) { maxValue := count * hashSize var incr int capacity := 1 - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, capacity)) + pool := bmt.NewPool(bmt.NewConf(count, capacity)) for n := 0; n <= maxValue; n += incr { h := pool.Get() incr = 1 + rand.Intn(5) @@ -125,7 +125,7 @@ func TestHasherReuse(t *testing.T) { func testHasherReuse(t *testing.T, poolsize int) { t.Helper() - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, poolsize)) + pool := bmt.NewPool(bmt.NewConf(testSegmentCount, poolsize)) h := pool.Get() defer pool.Put(h) @@ -145,7 +145,7 @@ func TestBMTConcurrentUse(t *testing.T) { t.Parallel() testData := testutil.RandBytesWithSeed(t, 4096, seed) - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize)) + pool := bmt.NewPool(bmt.NewConf(testSegmentCount, testPoolSize)) cycles := 100 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -178,7 +178,7 @@ func TestBMTWriterBuffers(t *testing.T) { t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { t.Parallel() - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, testPoolSize)) + pool := bmt.NewPool(bmt.NewConf(count, testPoolSize)) h := pool.Get() defer pool.Put(h) @@ -253,7 +253,7 @@ func TestBMTWriterBuffers(t *testing.T) { } // helper function that compares reference and optimised implementations for correctness -func testHasherCorrectness(h *bmt.Hasher, data []byte, n, count int) (err error) { +func testHasherCorrectness(h bmt.Hasher, data []byte, n, count int) (err error) { if len(data) < n { n = len(data) } @@ -271,11 +271,70 @@ func testHasherCorrectness(h *bmt.Hasher, data []byte, n, count int) (err error) return nil } +// TestGoroutineSIMDParity verifies that the goroutine and SIMD pools produce +// identical root hashes for the same input across every segment count in +// testSegmentCounts and a 
spread of write lengths. Iterating over the full +// segment-count table exercises the degenerate 1-level and shallow-tree branches +// of simdHasher.Hash that a fixed 128-segment test would miss. +// +// On platforms where the dispatcher falls back to the goroutine pool +// (non-linux/amd64 or CPU without AVX2/AVX-512), both pools end up as goroutine +// pools and the test degrades to a goroutine-vs-goroutine comparison. +// +// Sub-tests are intentionally not run in parallel: the SIMDOptIn flip is global +// state and concurrent NewPool calls would see flapping values. +func TestGoroutineSIMDParity(t *testing.T) { + testData := testutil.RandBytesWithSeed(t, 4096, seed) + + prev := bmt.SIMDOptIn() + defer bmt.SetSIMDOptIn(prev) + + // Representative write lengths; clamped per segment count below. 0 covers the + // empty-input path, and the neighbourhood of section boundaries (31/32/33, + // 63/64/65) tickles the padding/fall-through in Hash. + lengths := []int{0, 1, 31, 32, 33, 63, 64, 65, 128, 500, 1024, 2048, 4095, 4096} + + for _, count := range testSegmentCounts { + t.Run(fmt.Sprintf("segments_%d", count), func(t *testing.T) { + bmt.SetSIMDOptIn(false) + gPool := bmt.NewPool(bmt.NewConf(count, 1)) + bmt.SetSIMDOptIn(true) + sPool := bmt.NewPool(bmt.NewConf(count, 1)) + + maxLen := count * hashSize + for _, n := range lengths { + if n > maxLen { + continue + } + t.Run(fmt.Sprintf("len_%d", n), func(t *testing.T) { + gh := gPool.Get() + gHash, err := syncHash(gh, testData[:n]) + gPool.Put(gh) + if err != nil { + t.Fatal(err) + } + + sh := sPool.Get() + sHash, err := syncHash(sh, testData[:n]) + sPool.Put(sh) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(gHash, sHash) { + t.Fatalf("goroutine/simd mismatch at count=%d len=%d\n goroutine: %x\n simd: %x", count, n, gHash, sHash) + } + }) + } + }) + } +} + // verifies that the bmt.Hasher can be used with the hash.Hash interface func TestUseSyncAsOrdinaryHasher(t *testing.T) { t.Parallel() - pool := 
bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize)) + pool := bmt.NewPool(bmt.NewConf(testSegmentCount, testPoolSize)) h := pool.Get() defer pool.Put(h) data := []byte("moodbytesmoodbytesmoodbytesmoodbytes") diff --git a/pkg/bmt/dispatch_other.go b/pkg/bmt/dispatch_other.go new file mode 100644 index 00000000000..5b16eb0ba38 --- /dev/null +++ b/pkg/bmt/dispatch_other.go @@ -0,0 +1,23 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !linux || !amd64 || purego + +package bmt + +// NewPool returns a BMT pool. On this platform only the goroutine implementation +// is compiled in, so SIMDOptIn() is ignored. +func NewPool(c *Conf) Pool { + return newGoroutinePool(c) +} + +// NewHasher returns a standalone (non-pooled) BMT hasher. +func NewHasher() Hasher { + return newGoroutineHasher() +} + +// NewPrefixHasher returns a standalone BMT hasher with the given prefix. +func NewPrefixHasher(prefix []byte) Hasher { + return newGoroutinePrefixHasher(prefix) +} diff --git a/pkg/bmt/dispatch_simd.go b/pkg/bmt/dispatch_simd.go new file mode 100644 index 00000000000..1123dfe2f9a --- /dev/null +++ b/pkg/bmt/dispatch_simd.go @@ -0,0 +1,37 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package bmt + +import ( + "github.com/ethersphere/bee/v2/pkg/keccak" +) + +// NewPool returns a BMT pool. If SIMDOptIn() is true and the CPU exposes AVX2 +// or AVX-512, a SIMD-batched pool is returned. Otherwise the goroutine-based +// pool is returned (silent fallback). +func NewPool(c *Conf) Pool { + if SIMDOptIn() && keccak.HasSIMD() { + return newSIMDPool(c) + } + return newGoroutinePool(c) +} + +// NewHasher returns a standalone (non-pooled) BMT hasher. 
+func NewHasher() Hasher { + if SIMDOptIn() && keccak.HasSIMD() { + return newSIMDHasher() + } + return newGoroutineHasher() +} + +// NewPrefixHasher returns a standalone BMT hasher with the given prefix. +func NewPrefixHasher(prefix []byte) Hasher { + if SIMDOptIn() && keccak.HasSIMD() { + return newSIMDPrefixHasher(prefix) + } + return newGoroutinePrefixHasher(prefix) +} diff --git a/pkg/bmt/hasher_goroutine.go b/pkg/bmt/hasher_goroutine.go new file mode 100644 index 00000000000..2448d70515b --- /dev/null +++ b/pkg/bmt/hasher_goroutine.go @@ -0,0 +1,308 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bmt + +import ( + "encoding/binary" + + "github.com/ethersphere/bee/v2/pkg/swarm" +) + +var _ Hasher = (*goroutineHasher)(nil) + +// goroutineHasher is the goroutine-based BMT hasher implementation. +// It uses one goroutine per leaf section to hash sections concurrently, +// with atomic toggles to coordinate parent-node writes. +// +// The same hasher instance must not be called concurrently on more than one chunk. +// The same hasher instance is synchronously reusable after calling Reset. 
+type goroutineHasher struct { + *goroutineConf + bmt *goroutineTree + size int + pos int + result chan []byte + errc chan error + span []byte +} + +func newGoroutineHasher() *goroutineHasher { + gc := newGoroutineConf(nil, swarm.BmtBranches, 32) + return &goroutineHasher{ + goroutineConf: gc, + result: make(chan []byte), + errc: make(chan error, 1), + span: make([]byte, SpanSize), + bmt: newGoroutineTree(gc.maxSize, gc.depth, gc.hasherFunc), + } +} + +func newGoroutinePrefixHasher(prefix []byte) *goroutineHasher { + gc := newGoroutineConf(prefix, swarm.BmtBranches, 32) + return &goroutineHasher{ + goroutineConf: gc, + result: make(chan []byte), + errc: make(chan error, 1), + span: make([]byte, SpanSize), + bmt: newGoroutineTree(gc.maxSize, gc.depth, gc.hasherFunc), + } +} + +// Capacity returns the maximum number of bytes this hasher can process. +func (h *goroutineHasher) Capacity() int { return h.maxSize } + +// Size returns the digest size. +func (h *goroutineHasher) Size() int { return h.segmentSize } + +// BlockSize returns the optimal write block size. +func (h *goroutineHasher) BlockSize() int { return 2 * h.segmentSize } + +// Sum is the unsafe Hash wrapper. +func (h *goroutineHasher) Sum(b []byte) []byte { s, _ := h.Hash(b); return s } + +// SetHeaderInt64 sets the span preamble from an int64. +func (h *goroutineHasher) SetHeaderInt64(length int64) { + binary.LittleEndian.PutUint64(h.span, uint64(length)) +} + +// SetHeader copies the span preamble from the argument. +func (h *goroutineHasher) SetHeader(span []byte) { copy(h.span, span) } + +// Write appends to the chunk buffer; each complete section triggers a goroutine-level hash. 
+func (h *goroutineHasher) Write(b []byte) (int, error) { + l := len(b) + maxVal := h.maxSize - h.size + if l > maxVal { + l = maxVal + } + copy(h.bmt.buffer[h.size:], b) + secsize := 2 * h.segmentSize + from := h.size / secsize + h.size += l + to := h.size / secsize + if l == maxVal { + to-- + } + h.pos = to + for i := from; i < to; i++ { + go h.processSection(i, false) + } + return l, nil +} + +// Hash returns the BMT root hash of the buffer written so far. +func (h *goroutineHasher) Hash(b []byte) ([]byte, error) { + if h.size == 0 { + return doHash(h.baseHasher(), h.span, h.zerohashes[h.depth]) + } + copy(h.bmt.buffer[h.size:], zerosection) + // write the last section with final flag set to true + go h.processSection(h.pos, true) + select { + case result := <-h.result: + return doHash(h.baseHasher(), h.span, result) + case err := <-h.errc: + return nil, err + } +} + +// Reset prepares the Hasher for reuse. +func (h *goroutineHasher) Reset() { + h.pos = 0 + h.size = 0 + copy(h.span, zerospan) +} + +// Proof returns the inclusion proof of the i-th data segment. +func (h *goroutineHasher) Proof(i int) Proof { + index := i + if i < 0 || i > 127 { + panic("segment index can only lie between 0-127") + } + i = i / 2 + n := h.bmt.leaves[i] + isLeft := n.isLeft + var sisters [][]byte + for n = n.parent; n != nil; n = n.parent { + sisters = append(sisters, n.getSister(isLeft)) + isLeft = n.isLeft + } + + secsize := 2 * h.segmentSize + offset := i * secsize + section := make([]byte, secsize) + copy(section, h.bmt.buffer[offset:offset+secsize]) + segment, firstSegmentSister := section[:h.segmentSize], section[h.segmentSize:] + if index%2 != 0 { + segment, firstSegmentSister = firstSegmentSister, segment + } + sisters = append([][]byte{firstSegmentSister}, sisters...) + return Proof{segment, sisters, h.span, index} +} + +// Verify reconstructs the BMT root from a proof for the i-th segment. 
+func (h *goroutineHasher) Verify(i int, proof Proof) (root []byte, err error) { + var section []byte + if i%2 == 0 { + section = append(append(section, proof.ProveSegment...), proof.ProofSegments[0]...) + } else { + section = append(append(section, proof.ProofSegments[0]...), proof.ProveSegment...) + } + i = i / 2 + n := h.bmt.leaves[i] + hasher := h.baseHasher() + isLeft := n.isLeft + root, err = doHash(hasher, section) + if err != nil { + return nil, err + } + n = n.parent + + for _, sister := range proof.ProofSegments[1:] { + if isLeft { + root, err = doHash(hasher, root, sister) + } else { + root, err = doHash(hasher, sister, root) + } + if err != nil { + return nil, err + } + isLeft = n.isLeft + n = n.parent + } + return doHash(hasher, proof.Span, root) +} + +// processSection writes the hash of i-th section into level 1 node of the BMT tree. +func (h *goroutineHasher) processSection(i int, final bool) { + secsize := 2 * h.segmentSize + offset := i * secsize + level := 1 + // select the leaf node for the section + n := h.bmt.leaves[i] + isLeft := n.isLeft + hasher := n.hasher + n = n.parent + // hash the section + section, err := doHash(hasher, h.bmt.buffer[offset:offset+secsize]) + if err != nil { + select { + case h.errc <- err: + default: + } + return + } + // write hash into parent node + if final { + // for the last segment use writeFinalNode + h.writeFinalNode(level, n, isLeft, section) + } else { + h.writeNode(n, isLeft, section) + } +} + +// writeNode pushes the data to the node. +// if it is the first of 2 sisters written, the routine terminates. +// if it is the second, it calculates the hash and writes it +// to the parent node recursively. +// since hashing the parent is synchronous the same hasher can be used. 
+func (h *goroutineHasher) writeNode(n *goroutineNode, isLeft bool, s []byte) { + var err error + for { + // at the root of the bmt just write the result to the result channel + if n == nil { + h.result <- s + return + } + // otherwise assign child hash to left or right segment + if isLeft { + n.left = s + } else { + n.right = s + } + // the child-thread first arriving will terminate + if n.toggle() { + return + } + // the thread coming second now can be sure both left and right children are written + // so it calculates the hash of left|right and pushes it to the parent + s, err = doHash(n.hasher, n.left, n.right) + if err != nil { + select { + case h.errc <- err: + default: + } + return + } + isLeft = n.isLeft + n = n.parent + } +} + +// writeFinalNode is following the path starting from the final datasegment to the +// BMT root via parents. +// For unbalanced trees it fills in the missing right sister nodes using +// the pool's lookup table for BMT subtree root hashes for all-zero sections. +// Otherwise behaves like `writeNode`. +func (h *goroutineHasher) writeFinalNode(level int, n *goroutineNode, isLeft bool, s []byte) { + var err error + for { + // at the root of the bmt just write the result to the result channel + if n == nil { + if s != nil { + h.result <- s + } + return + } + var noHash bool + if isLeft { + // coming from left sister branch + // when the final section's path is going via left child node + // we include an all-zero subtree hash for the right level and toggle the node. 
+ n.right = h.zerohashes[level] + if s != nil { + n.left = s + // if a left final node carries a hash, it must be the first (and only thread) + // so the toggle is already in passive state, no need to call it + // yet thread needs to carry on pushing hash to parent + noHash = false + } else { + // if again first thread then propagate nil and calculate no hash + noHash = n.toggle() + } + } else { + // right sister branch + if s != nil { + // if hash was pushed from right child node, write right segment and change state + n.right = s + // if toggle is true, we arrived first so no hashing just push nil to parent + noHash = n.toggle() + } else { + // if s is nil, then thread arrived first at previous node and here there will be two, + // so no need to do anything and keep s = nil for parent + noHash = true + } + } + // the child-thread first arriving will just continue resetting s to nil + // the second thread now can be sure both left and right children are written + // it calculates the hash of left|right and pushes it to the parent + if noHash { + s = nil + } else { + s, err = doHash(n.hasher, n.left, n.right) + if err != nil { + select { + case h.errc <- err: + default: + } + return + } + } + // iterate to parent + isLeft = n.isLeft + n = n.parent + level++ + } +} diff --git a/pkg/bmt/hasher_simd.go b/pkg/bmt/hasher_simd.go new file mode 100644 index 00000000000..56ff4b14d56 --- /dev/null +++ b/pkg/bmt/hasher_simd.go @@ -0,0 +1,126 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package bmt + +import ( + "encoding/binary" + + "github.com/ethersphere/bee/v2/pkg/swarm" +) + +var _ Hasher = (*simdHasher)(nil) + +// simdHasher is a BMT hasher implementation that buffers all data and defers hashing +// to Hash(), using SIMD-accelerated Keccak for multi-level trees.
+// +// Single-threaded: SIMD provides intra-call parallelism (4-way or 8-way), replacing +// the goroutine-per-section model of the goroutine hasher. +// +// The same hasher instance must not be called concurrently on more than one chunk. +// The same hasher instance is synchronously reusable after calling Reset. +type simdHasher struct { + *simdConf + bmt *simdTree + size int + span []byte +} + +func newSIMDHasher() *simdHasher { + sc := newSIMDConf(nil, swarm.BmtBranches, 32) + return &simdHasher{ + simdConf: sc, + span: make([]byte, SpanSize), + bmt: newSIMDTree(sc.maxSize, sc.depth, sc.baseHasher, sc.prefix), + } +} + +func newSIMDPrefixHasher(prefix []byte) *simdHasher { + sc := newSIMDConf(prefix, swarm.BmtBranches, 32) + return &simdHasher{ + simdConf: sc, + span: make([]byte, SpanSize), + bmt: newSIMDTree(sc.maxSize, sc.depth, sc.baseHasher, sc.prefix), + } +} + +// Capacity returns the maximum number of bytes this hasher can process. +func (h *simdHasher) Capacity() int { return h.maxSize } + +// Size returns the digest size. +func (h *simdHasher) Size() int { return h.segmentSize } + +// BlockSize returns the optimal write block size. +func (h *simdHasher) BlockSize() int { return 2 * h.segmentSize } + +// Sum is the unsafe Hash wrapper. +func (h *simdHasher) Sum(b []byte) []byte { s, _ := h.Hash(b); return s } + +// SetHeaderInt64 sets the span preamble from an int64. +func (h *simdHasher) SetHeaderInt64(length int64) { + binary.LittleEndian.PutUint64(h.span, uint64(length)) +} + +// SetHeader copies the span preamble from the argument. +func (h *simdHasher) SetHeader(span []byte) { copy(h.span, span) } + +// Write buffers input to be hashed. All hashing is deferred to Hash(). +func (h *simdHasher) Write(b []byte) (int, error) { + l := len(b) + maxVal := h.maxSize - h.size + if l > maxVal { + l = maxVal + } + copy(h.bmt.buffer[h.size:], b) + h.size += l + return l, nil +} + +// Hash returns the BMT root hash of the buffer written so far. 
+func (h *simdHasher) Hash(b []byte) ([]byte, error) { + // empty input: no data was ever written, so the BMT root is the all-zero + // subtree hash at depth h.depth. Still prepend the span so the output + // shape matches a normal chunk hash. + if h.size == 0 { + return doHash(h.baseHasher(), h.span, h.zerohashes[h.depth]) + } + // zero-fill the tail of the buffer: every leaf section must carry + // deterministic bytes, because SIMD batches hash whole sections at a time + // without a "is this section occupied?" check. clear() lowers to memclr. + clear(h.bmt.buffer[h.size:]) + // degenerate single-level tree: the whole buffer is already one section, + // so there is nothing for SIMD to batch — just run the scalar hasher. + if len(h.bmt.levels) == 1 { + secsize := 2 * h.segmentSize + root := h.bmt.levels[0][0] + rootHash, err := doHash(root.hasher, h.bmt.buffer[:secsize]) + if err != nil { + return nil, err + } + return doHash(h.baseHasher(), h.span, rootHash) + } + // general case: fan out through the tree via hashSIMD, which processes + // each level bottom-up in batches of 4 or 8 hashes per SIMD call. + rootHash, err := h.hashSIMD() + if err != nil { + return nil, err + } + // prepend the span and hash once more to produce the chunk address. + // When a prefix is configured, baseHasher() returns a PrefixHasher whose + // Reset re-absorbs the prefix, so the final digest is + // keccak(prefix || span || rootHash) — the same wrap-per-level rule used + // by hashSIMD at every internal level. + return doHash(h.baseHasher(), h.span, rootHash) +} + +// Reset prepares the Hasher for reuse. The internal data buffer is not zeroed; +// any stale bytes past h.size are overwritten on the next Write and zero-filled +// on demand by Hash, so post-Reset inspection of the buffer may still see +// previous-chunk contents. 
+func (h *simdHasher) Reset() { + h.size = 0 + copy(h.span, zerospan) +} diff --git a/pkg/bmt/interface.go b/pkg/bmt/interface.go index 767892365b5..b2e81ca9c23 100644 --- a/pkg/bmt/interface.go +++ b/pkg/bmt/interface.go @@ -12,11 +12,15 @@ const ( SpanSize = 8 ) -// Hash provides the necessary extension of the hash interface to add the length-prefix of the BMT hash. +// Hasher provides the necessary extension of the hash interface to add the length-prefix of the BMT hash. // // Any implementation should make it possible to generate a BMT hash using the hash.Hash interface only. // However, the limitation will be that the Span of the BMT hash always must be limited to the amount of bytes actually written. -type Hash interface { +// +// Inclusion-proof generation (Proof / Verify / zero-padded Hash) is NOT part of +// this interface — it lives on the Prover type, which is always goroutine-backed +// regardless of SIMDOptIn. See pkg/bmt/proof.go. +type Hasher interface { hash.Hash // SetHeaderInt64 sets the header bytes of BMT hash to the little endian binary representation of the int64 argument. diff --git a/pkg/bmt/pool.go b/pkg/bmt/pool.go index a7f7245c40d..5be36aab8e5 100644 --- a/pkg/bmt/pool.go +++ b/pkg/bmt/pool.go @@ -4,152 +4,29 @@ package bmt -import ( - "hash" - "sync/atomic" -) - -// BaseHasherFunc is a hash.Hash constructor function used for the base hash of the BMT. 
-// implemented by Keccak256 SHA3 sha3.NewLegacyKeccak256 -type BaseHasherFunc func() hash.Hash - -// configuration -type Conf struct { - segmentSize int // size of leaf segments, stipulated to be = hash size - segmentCount int // the number of segments on the base level of the BMT - capacity int // pool capacity, controls concurrency - depth int // depth of the bmt trees = int(log2(segmentCount))+1 - maxSize int // the total length of the data (count * size) - zerohashes [][]byte // lookup table for predictable padding subtrees for all levels - hasher BaseHasherFunc // base hasher to use for the BMT levels -} - -// Pool provides a pool of trees used as resources by the BMT Hasher. -// A tree popped from the pool is guaranteed to have a clean state ready -// for hashing a new chunk. -type Pool struct { - c chan *tree // the channel to obtain a resource from the pool - *Conf // configuration -} - -func NewConf(hasher BaseHasherFunc, segmentCount, capacity int) *Conf { - count, depth := sizeToParams(segmentCount) - segmentSize := hasher().Size() - zerohashes := make([][]byte, depth+1) - zeros := make([]byte, segmentSize) - zerohashes[0] = zeros - var err error - // initialises the zerohashes lookup table - for i := 1; i < depth+1; i++ { - if zeros, err = doHash(hasher(), zeros, zeros); err != nil { - panic(err.Error()) - } - zerohashes[i] = zeros - } - return &Conf{ - hasher: hasher, - segmentSize: segmentSize, - segmentCount: segmentCount, - capacity: capacity, - maxSize: count * segmentSize, - depth: depth, - zerohashes: zerohashes, - } -} - -// NewPool creates a tree pool with hasher, segment size, segment count and capacity -// it reuses free trees or creates a new one if capacity is not reached. 
-func NewPool(c *Conf) *Pool { - p := &Pool{ - Conf: c, - c: make(chan *tree, c.capacity), - } - for i := 0; i < c.capacity; i++ { - p.c <- newTree(p.maxSize, p.depth, p.hasher) - } - return p -} - -// Get returns a BMT hasher possibly reusing a tree from the pool -func (p *Pool) Get() *Hasher { - t := <-p.c - return &Hasher{ - Conf: p.Conf, - result: make(chan []byte), - errc: make(chan error, 1), - span: make([]byte, SpanSize), - bmt: t, - } -} - -// Put is called after using a bmt hasher to return the tree to a pool for reuse -func (p *Pool) Put(h *Hasher) { - p.c <- h.bmt -} - -// tree is a reusable control structure representing a BMT -// organised in a binary tree -// -// Hasher uses a Pool to obtain a tree for each chunk hash -// the tree is 'locked' while not in the pool. -type tree struct { - leaves []*node // leaf nodes of the tree, other nodes accessible via parent links - buffer []byte +// Pool is the interface all BMT hasher pools satisfy. +type Pool interface { + // Get returns a BMT hasher, possibly reusing one from the pool. + Get() Hasher + // Put returns a hasher to the pool for reuse. + Put(Hasher) } -// node is a reusable segment hasher representing a node in a BMT. -type node struct { - isLeft bool // whether it is left side of the parent double segment - parent *node // pointer to parent node in the BMT - state int32 // atomic increment impl concurrent boolean toggle - left, right []byte // this is where the two children sections are written - hasher hash.Hash // preconstructed hasher on nodes -} - -// newNode constructs a segment hasher node in the BMT (used by newTree). 
-func newNode(index int, parent *node, hasher hash.Hash) *node { - return &node{ - parent: parent, - isLeft: index%2 == 0, - hasher: hasher, - } -} - -// newTree initialises a tree by building up the nodes of a BMT -func newTree(maxsize, depth int, hashfunc func() hash.Hash) *tree { - n := newNode(0, nil, hashfunc()) - prevlevel := []*node{n} - // iterate over levels and creates 2^(depth-level) nodes - // the 0 level is on double segment sections so we start at depth - 2 - count := 2 - for level := depth - 2; level >= 0; level-- { - nodes := make([]*node, count) - for i := 0; i < count; i++ { - parent := prevlevel[i/2] - nodes[i] = newNode(i, parent, hashfunc()) - } - prevlevel = nodes - count *= 2 - } - // the datanode level is the nodes on the last level - return &tree{ - leaves: prevlevel, - buffer: make([]byte, maxsize), - } +// Conf captures the user-facing configuration of a BMT hasher pool. +// Implementation-specific derived state (tree depth, batch width, etc.) +// is computed internally by each pool implementation. +type Conf struct { + SegmentCount int + Capacity int + Prefix []byte } -// atomic bool toggle implementing a concurrent reusable 2-state object. -// Atomic addint with %2 implements atomic bool toggle. -// It returns true if the toggler just put it in the active/waiting state. -func (n *node) toggle() bool { - return atomic.AddInt32(&n.state, 1)%2 == 1 +// NewConf returns a new Conf for the given segment count and pool capacity. +func NewConf(segmentCount, capacity int) *Conf { + return &Conf{SegmentCount: segmentCount, Capacity: capacity} } -// sizeToParams calculates the depth (number of levels) and segment count in the BMT tree. -func sizeToParams(n int) (c, d int) { - c = 2 - for ; c < n; c *= 2 { - d++ - } - return c, d + 1 +// NewConfWithPrefix is like NewConf but with an optional prefix prepended to every hash operation. 
+func NewConfWithPrefix(prefix []byte, segmentCount, capacity int) *Conf { + return &Conf{SegmentCount: segmentCount, Capacity: capacity, Prefix: prefix} } diff --git a/pkg/bmt/pool_goroutine.go b/pkg/bmt/pool_goroutine.go new file mode 100644 index 00000000000..e8764c10646 --- /dev/null +++ b/pkg/bmt/pool_goroutine.go @@ -0,0 +1,206 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bmt + +import ( + "hash" + "sync/atomic" + + "github.com/ethersphere/bee/v2/pkg/swarm" +) + +// goroutineConf is the internal configuration for the goroutine BMT pool. +type goroutineConf struct { + segmentSize int + segmentCount int + capacity int + depth int + maxSize int + zerohashes [][]byte + prefix []byte + hasherFunc func() hash.Hash +} + +func (c *goroutineConf) baseHasher() hash.Hash { + if len(c.prefix) > 0 { + return swarm.NewPrefixHasher(c.prefix) + } + return swarm.NewHasher() +} + +// goroutinePool is the goroutine-based BMT hasher pool. 
+type goroutinePool struct { + c chan *goroutineTree + *goroutineConf +} + +func newGoroutineConf(prefix []byte, segmentCount, capacity int) *goroutineConf { + count, depth := sizeToParams(segmentCount) + segmentSize := SEGMENT_SIZE + + hasherFunc := func() hash.Hash { + if len(prefix) > 0 { + return swarm.NewPrefixHasher(prefix) + } + return swarm.NewHasher() + } + + c := &goroutineConf{ + segmentSize: segmentSize, + segmentCount: segmentCount, + capacity: capacity, + maxSize: count * segmentSize, + depth: depth, + prefix: prefix, + hasherFunc: hasherFunc, + } + + zerohashes := make([][]byte, depth+1) + zeros := make([]byte, segmentSize) + zerohashes[0] = zeros + var err error + for i := 1; i < depth+1; i++ { + if zeros, err = doHash(c.baseHasher(), zeros, zeros); err != nil { + panic(err.Error()) + } + zerohashes[i] = zeros + } + c.zerohashes = zerohashes + + return c +} + +// newGoroutinePool creates a goroutine BMT pool from a public Conf. +func newGoroutinePool(c *Conf) *goroutinePool { + gc := newGoroutineConf(c.Prefix, c.SegmentCount, c.Capacity) + p := &goroutinePool{ + goroutineConf: gc, + c: make(chan *goroutineTree, gc.capacity), + } + for i := 0; i < gc.capacity; i++ { + p.c <- newGoroutineTree(gc.maxSize, gc.depth, gc.hasherFunc) + } + return p +} + +// Get returns a BMT hasher, possibly reusing a tree from the pool. +func (p *goroutinePool) Get() Hasher { + t := <-p.c + return &goroutineHasher{ + goroutineConf: p.goroutineConf, + result: make(chan []byte), + errc: make(chan error, 1), + span: make([]byte, SpanSize), + bmt: t, + } +} + +// Put returns a hasher's tree to the pool for reuse. +func (p *goroutinePool) Put(h Hasher) { + gh, ok := h.(*goroutineHasher) + if !ok { + panic("bmt: goroutinePool.Put called with non-goroutineHasher") + } + p.c <- gh.bmt +} + +// goroutineProverPool is a pool of goroutine-backed Provers. Trees are reused +// via a channel; the hasher shell is reallocated per GetProver (cheap relative +// to a tree). 
+type goroutineProverPool struct { + c chan *goroutineTree + *goroutineConf +} + +func newGoroutineProverPool(c *Conf) *goroutineProverPool { + gc := newGoroutineConf(c.Prefix, c.SegmentCount, c.Capacity) + p := &goroutineProverPool{ + goroutineConf: gc, + c: make(chan *goroutineTree, gc.capacity), + } + for i := 0; i < gc.capacity; i++ { + p.c <- newGoroutineTree(gc.maxSize, gc.depth, gc.hasherFunc) + } + return p +} + +// GetProver returns a goroutine-backed Prover, reusing a tree from the pool. +func (p *goroutineProverPool) GetProver() *Prover { + t := <-p.c + return &Prover{ + goroutineHasher: &goroutineHasher{ + goroutineConf: p.goroutineConf, + result: make(chan []byte), + errc: make(chan error, 1), + span: make([]byte, SpanSize), + bmt: t, + }, + } +} + +// PutProver returns a Prover's tree to the pool for reuse. +func (p *goroutineProverPool) PutProver(pr *Prover) { + p.c <- pr.bmt +} + +// goroutineTree is the tree structure used by the goroutine hasher. +type goroutineTree struct { + leaves []*goroutineNode + buffer []byte +} + +// goroutineNode is a reusable segment hasher node in the goroutine BMT. 
+type goroutineNode struct { + isLeft bool + parent *goroutineNode + state int32 + left, right []byte + hasher hash.Hash +} + +func newGoroutineNode(index int, parent *goroutineNode, hasher hash.Hash) *goroutineNode { + return &goroutineNode{ + parent: parent, + isLeft: index%2 == 0, + hasher: hasher, + } +} + +func newGoroutineTree(maxsize, depth int, hashfunc func() hash.Hash) *goroutineTree { + n := newGoroutineNode(0, nil, hashfunc()) + prevlevel := []*goroutineNode{n} + count := 2 + for level := depth - 2; level >= 0; level-- { + nodes := make([]*goroutineNode, count) + for i := 0; i < count; i++ { + parent := prevlevel[i/2] + nodes[i] = newGoroutineNode(i, parent, hashfunc()) + } + prevlevel = nodes + count *= 2 + } + return &goroutineTree{ + leaves: prevlevel, + buffer: make([]byte, maxsize), + } +} + +// toggle implements atomic bool toggle for goroutineNode coordination. +func (n *goroutineNode) toggle() bool { + return atomic.AddInt32(&n.state, 1)%2 == 1 +} + +// getSister returns a copy of the sibling section on the opposite side. +func (n *goroutineNode) getSister(isLeft bool) []byte { + var src []byte + if isLeft { + src = n.right + } else { + src = n.left + } + buf := make([]byte, len(src)) + copy(buf, src) + return buf +} diff --git a/pkg/bmt/pool_simd.go b/pkg/bmt/pool_simd.go new file mode 100644 index 00000000000..95f8323f175 --- /dev/null +++ b/pkg/bmt/pool_simd.go @@ -0,0 +1,170 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package bmt + +import ( + "hash" + + "github.com/ethersphere/bee/v2/pkg/keccak" + "github.com/ethersphere/bee/v2/pkg/swarm" +) + +// simdConf is the internal configuration for the SIMD BMT pool. 
+type simdConf struct { + segmentSize int + segmentCount int + capacity int + depth int + maxSize int + zerohashes [][]byte + prefix []byte + batchWidth int +} + +func (c *simdConf) baseHasher() hash.Hash { + if len(c.prefix) > 0 { + return swarm.NewPrefixHasher(c.prefix) + } + return swarm.NewHasher() +} + +// simdPool is the SIMD-batched BMT hasher pool. +type simdPool struct { + c chan *simdTree + *simdConf +} + +func newSIMDConf(prefix []byte, segmentCount, capacity int) *simdConf { + count, depth := sizeToParams(segmentCount) + segmentSize := SEGMENT_SIZE + + c := &simdConf{ + segmentSize: segmentSize, + segmentCount: segmentCount, + capacity: capacity, + maxSize: count * segmentSize, + depth: depth, + prefix: prefix, + batchWidth: keccak.BatchWidth(), + } + + zerohashes := make([][]byte, depth+1) + zeros := make([]byte, segmentSize) + zerohashes[0] = zeros + var err error + for i := 1; i < depth+1; i++ { + if zeros, err = doHash(c.baseHasher(), zeros, zeros); err != nil { + panic(err.Error()) + } + zerohashes[i] = zeros + } + c.zerohashes = zerohashes + + return c +} + +// newSIMDPool creates a SIMD BMT pool from a public Conf. +func newSIMDPool(c *Conf) *simdPool { + sc := newSIMDConf(c.Prefix, c.SegmentCount, c.Capacity) + p := &simdPool{ + simdConf: sc, + c: make(chan *simdTree, sc.capacity), + } + for i := 0; i < sc.capacity; i++ { + p.c <- newSIMDTree(sc.maxSize, sc.depth, sc.baseHasher, sc.prefix) + } + return p +} + +func (p *simdPool) Get() Hasher { + t := <-p.c + return &simdHasher{ + simdConf: p.simdConf, + span: make([]byte, SpanSize), + bmt: t, + } +} + +func (p *simdPool) Put(h Hasher) { + sh, ok := h.(*simdHasher) + if !ok { + panic("bmt: simdPool.Put called with non-simdHasher") + } + p.c <- sh.bmt +} + +// simdTree is the tree structure used by the SIMD hasher. 
+type simdTree struct { + leaves []*simdNode + levels [][]*simdNode + buffer []byte + concat [8][]byte + leafConcat [8][]byte +} + +// simdNode is a reusable segment hasher node in the SIMD BMT. +type simdNode struct { + isLeft bool + parent *simdNode + left, right []byte + hasher hash.Hash +} + +func newSIMDNode(index int, parent *simdNode, hasher hash.Hash) *simdNode { + return &simdNode{ + parent: parent, + isLeft: index%2 == 0, + hasher: hasher, + left: make([]byte, hasher.Size()), + right: make([]byte, hasher.Size()), + } +} + +func newSIMDTree(maxsize, depth int, hashfunc func() hash.Hash, prefix []byte) *simdTree { + prefixLen := len(prefix) + n := newSIMDNode(0, nil, hashfunc()) + prevlevel := []*simdNode{n} + allLevels := [][]*simdNode{prevlevel} + count := 2 + for level := depth - 2; level >= 0; level-- { + nodes := make([]*simdNode, count) + for i := 0; i < count; i++ { + parent := prevlevel[i/2] + nodes[i] = newSIMDNode(i, parent, hashfunc()) + } + allLevels = append(allLevels, nodes) + prevlevel = nodes + count *= 2 + } + // reverse so levels[0]=leaves, levels[len-1]=root + for i, j := 0, len(allLevels)-1; i < j; i, j = i+1, j-1 { + allLevels[i], allLevels[j] = allLevels[j], allLevels[i] + } + segSize := hashfunc().Size() + bufSize := prefixLen + 2*segSize + var concat [8][]byte + for i := range concat { + concat[i] = make([]byte, bufSize) + if prefixLen > 0 { + copy(concat[i][:prefixLen], prefix) + } + } + var leafConcat [8][]byte + for i := range leafConcat { + leafConcat[i] = make([]byte, prefixLen+2*segSize) + if prefixLen > 0 { + copy(leafConcat[i][:prefixLen], prefix) + } + } + return &simdTree{ + leaves: prevlevel, + levels: allLevels, + buffer: make([]byte, maxsize), + concat: concat, + leafConcat: leafConcat, + } +} diff --git a/pkg/bmt/proof.go b/pkg/bmt/proof.go index b08017b43cd..0e90a3a9fa6 100644 --- a/pkg/bmt/proof.go +++ b/pkg/bmt/proof.go @@ -4,12 +4,21 @@ package bmt -// Prover wraps the Hasher to allow Merkle proof functionality +// 
Prover adds Merkle-proof generation and verification on top of a BMT hasher. +// +// Proof generation requires the tree to be fully populated, so Prover.Hash +// zero-pads any unwritten sections before hashing. Prover is always backed by +// the goroutine BMT implementation, independent of the SIMD opt-in flag: +// proofs are produced on a rare, redistribution-only code path where the +// well-tested goroutine implementation is preferred over the SIMD speedup. +// +// The embedded *goroutineHasher is package-private, so callers outside +// pkg/bmt cannot bypass Hash/Sum to skip padding. type Prover struct { - *Hasher + *goroutineHasher } -// Proof represents a Merkle proof of segment +// Proof represents a Merkle proof of segment. type Proof struct { ProveSegment []byte ProofSegments [][]byte @@ -17,89 +26,46 @@ type Proof struct { Index int } -// Hash overrides base hash function of Hasher to fill buffer with zeros until chunk length -func (p Prover) Hash(b []byte) ([]byte, error) { +// Hash zero-pads any unwritten sections (so every leaf section in the BMT is +// populated and Proof paths are reconstructible), then computes the BMT root. +// Shadows the promoted goroutineHasher.Hash. 
+func (p *Prover) Hash(b []byte) ([]byte, error) { for i := p.size; i < p.maxSize; i += len(zerosection) { - _, err := p.Write(zerosection) - if err != nil { + if _, err := p.Write(zerosection); err != nil { return nil, err } } - return p.Hasher.Hash(b) + return p.goroutineHasher.Hash(b) } -// Proof returns the inclusion proof of the i-th data segment -func (p Prover) Proof(i int) Proof { - index := i - - if i < 0 || i > 127 { - panic("segment index can only lie between 0-127") - } - i = i / 2 - n := p.bmt.leaves[i] - isLeft := n.isLeft - var sisters [][]byte - for n = n.parent; n != nil; n = n.parent { - sisters = append(sisters, n.getSister(isLeft)) - isLeft = n.isLeft - } - - secsize := 2 * p.segmentSize - offset := i * secsize - section := make([]byte, secsize) - copy(section, p.bmt.buffer[offset:offset+secsize]) - segment, firstSegmentSister := section[:p.segmentSize], section[p.segmentSize:] - if index%2 != 0 { - segment, firstSegmentSister = firstSegmentSister, segment - } - sisters = append([][]byte{firstSegmentSister}, sisters...) - return Proof{segment, sisters, p.span, index} +// Sum shadows the promoted goroutineHasher.Sum so Prover used as a hash.Hash +// also pads. Without this override Sum would call the unpadded Hash and +// produce a different digest than Prover.Hash. +func (p *Prover) Sum(b []byte) []byte { + s, _ := p.Hash(b) + return s } -// Verify returns the bmt hash obtained from the proof which can then be checked against -// the BMT hash of the chunk -func (p Prover) Verify(i int, proof Proof) (root []byte, err error) { - var section []byte - if i%2 == 0 { - section = append(append(section, proof.ProveSegment...), proof.ProofSegments[0]...) - } else { - section = append(append(section, proof.ProofSegments[0]...), proof.ProveSegment...) 
- } - i = i / 2 - n := p.bmt.leaves[i] - hasher := p.hasher() - isLeft := n.isLeft - root, err = doHash(hasher, section) - if err != nil { - return nil, err - } - n = n.parent - - for _, sister := range proof.ProofSegments[1:] { - if isLeft { - root, err = doHash(hasher, root, sister) - } else { - root, err = doHash(hasher, sister, root) - } - if err != nil { - return nil, err - } - isLeft = n.isLeft - n = n.parent - } - return doHash(hasher, proof.Span, root) +// NewProver returns a Prover backed by a freshly allocated goroutine-based +// BMT hasher, independent of the SIMD opt-in flag. +func NewProver() *Prover { + return &Prover{goroutineHasher: newGoroutineHasher()} } -func (n *node) getSister(isLeft bool) []byte { - var src []byte +// NewPrefixProver is NewProver with an optional keccak prefix prepended to every +// BMT node hash. Also goroutine-backed regardless of SIMDOptIn. +func NewPrefixProver(prefix []byte) *Prover { + return &Prover{goroutineHasher: newGoroutinePrefixHasher(prefix)} +} - if isLeft { - src = n.right - } else { - src = n.left - } +// ProverPool is a pool of goroutine-backed Provers. Ignores SIMDOptIn by design. +type ProverPool interface { + GetProver() *Prover + PutProver(*Prover) +} - buf := make([]byte, len(src)) - copy(buf, src) - return buf +// NewProverPool returns a pool of goroutine-backed Provers, independent of +// SIMDOptIn. See Prover for the rationale. 
+func NewProverPool(c *Conf) ProverPool { + return newGoroutineProverPool(c) } diff --git a/pkg/bmt/proof_test.go b/pkg/bmt/proof_test.go index 4658cb353c7..c1145bf5e80 100644 --- a/pkg/bmt/proof_test.go +++ b/pkg/bmt/proof_test.go @@ -46,18 +46,17 @@ func TestProofCorrectness(t *testing.T) { } } - pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, 128, 128)) - hh := pool.Get() + pool := bmt.NewProverPool(bmt.NewConf(128, 128)) + pr := pool.GetProver() t.Cleanup(func() { - pool.Put(hh) + pool.PutProver(pr) }) - hh.SetHeaderInt64(4096) + pr.SetHeaderInt64(4096) - _, err := hh.Write(testData) + _, err := pr.Write(testData) if err != nil { t.Fatal(err) } - pr := bmt.Prover{hh} rh, err := pr.Hash(nil) if err != nil { t.Fatal(err) @@ -80,7 +79,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testDataPadded[:hh.Size()]) { + if !bytes.Equal(proof.ProveSegment, testDataPadded[:pr.Size()]) { t.Fatal("section incorrect") } @@ -106,7 +105,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testDataPadded[127*hh.Size():]) { + if !bytes.Equal(proof.ProveSegment, testDataPadded[127*pr.Size():]) { t.Fatal("section incorrect") } @@ -132,7 +131,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testDataPadded[64*hh.Size():65*hh.Size()]) { + if !bytes.Equal(proof.ProveSegment, testDataPadded[64*pr.Size():65*pr.Size()]) { t.Fatal("section incorrect") } @@ -163,7 +162,7 @@ func TestProofCorrectness(t *testing.T) { segments = append(segments, decoded) } - segment := testDataPadded[64*hh.Size() : 65*hh.Size()] + segment := testDataPadded[64*pr.Size() : 65*pr.Size()] rootHash, err := pr.Verify(64, bmt.Proof{ ProveSegment: segment, @@ -191,20 +190,19 @@ func TestProof(t *testing.T) { t.Fatal(err) } - pool := 
bmt.NewPool(bmt.NewConf(swarm.NewHasher, 128, 128)) - hh := pool.Get() + pool := bmt.NewProverPool(bmt.NewConf(128, 128)) + pr := pool.GetProver() t.Cleanup(func() { - pool.Put(hh) + pool.PutProver(pr) }) - hh.SetHeaderInt64(4096) + pr.SetHeaderInt64(4096) - _, err = hh.Write(buf) + _, err = pr.Write(buf) if err != nil { t.Fatal(err) } - rh, err := hh.Hash(nil) - pr := bmt.Prover{hh} + rh, err := pr.Hash(nil) if err != nil { t.Fatal(err) } @@ -215,10 +213,10 @@ func TestProof(t *testing.T) { proof := pr.Proof(i) - h := pool.Get() - defer pool.Put(h) + verifier := pool.GetProver() + defer pool.PutProver(verifier) - root, err := bmt.Prover{h}.Verify(i, proof) + root, err := verifier.Verify(i, proof) if err != nil { t.Fatal(err) } diff --git a/pkg/bmtpool/bmtpool.go b/pkg/bmtpool/bmtpool.go index 88c1ad32dba..fd57fc9aeb9 100644 --- a/pkg/bmtpool/bmtpool.go +++ b/pkg/bmtpool/bmtpool.go @@ -7,26 +7,65 @@ package bmtpool import ( + "sync/atomic" + "github.com/ethersphere/bee/v2/pkg/bmt" "github.com/ethersphere/bee/v2/pkg/swarm" ) const Capacity = 32 -var instance *bmt.Pool +// instance holds the active bmt.Pool. Using an atomic pointer so Rebuild can +// swap it in after cmd/bee calls bmt.SetSIMDOptIn during startup, without +// introducing lock contention on the hot Get/Put path. +var instance atomic.Pointer[bmt.Pool] + +// proverInstance holds the Prover pool. Goroutine-backed by construction +// (see bmt.NewProverPool) and intentionally not affected by Rebuild or the +// SIMD opt-in flag: proofs run on a rare, redistribution-only code path where +// the well-tested goroutine implementation is preferred. +var proverInstance bmt.ProverPool // nolint:gochecknoinits func init() { - instance = bmt.NewPool(bmt.NewConf(swarm.NewHasher, swarm.BmtBranches, Capacity)) + // Eager init: construct the pool at package load time so its internal + // channel is created outside any testing/synctest bubble. 
Tests that use + // synctest would otherwise trip "receive on synctest channel from outside + // bubble" if the pool were first created inside a bubble. + // + // bmt.SIMDOptIn() is still false here (flag hasn't been parsed). cmd/bee + // calls Rebuild after flag parsing if the user opted in. + p := bmt.NewPool(bmt.NewConf(swarm.BmtBranches, Capacity)) + instance.Store(&p) + proverInstance = bmt.NewProverPool(bmt.NewConf(swarm.BmtBranches, Capacity)) +} + +// Rebuild discards the current pool and constructs a new one reading the +// latest bmt.SIMDOptIn() value. Must be called exactly once during startup +// after CLI flag parsing and before the first external Get, so that +// --use-simd-hashing takes effect and no in-flight hashers are referencing the +// old pool's tree. The Prover pool is intentionally not rebuilt. +func Rebuild() { + p := bmt.NewPool(bmt.NewConf(swarm.BmtBranches, Capacity)) + instance.Store(&p) +} + +// Get a bmt Hasher instance. Instances are reset before being returned to the caller. +func Get() bmt.Hasher { + return (*instance.Load()).Get() +} + +// Put a bmt Hasher back into the pool. +func Put(h bmt.Hasher) { + (*instance.Load()).Put(h) } -// Get a bmt Hasher instance. -// Instances are reset before being returned to the caller. -func Get() *bmt.Hasher { - return instance.Get() +// GetProver returns a goroutine-backed Prover from the global prover pool. +func GetProver() *bmt.Prover { + return proverInstance.GetProver() } -// Put a bmt Hasher back into the pool -func Put(h *bmt.Hasher) { - instance.Put(h) +// PutProver returns a Prover to the global prover pool for reuse. +func PutProver(p *bmt.Prover) { + proverInstance.PutProver(p) } diff --git a/pkg/keccak/keccak.go b/pkg/keccak/keccak.go new file mode 100644 index 00000000000..43b14eee713 --- /dev/null +++ b/pkg/keccak/keccak.go @@ -0,0 +1,71 @@ +// Copyright 2026 The Swarm Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package keccak provides legacy Keccak-256 (Ethereum-compatible) hashing +// with SIMD acceleration via XKCP. +// +// Sum256x4 uses AVX2 (4-way parallel). Sum256x8 uses AVX-512 (8-way parallel). +// The caller is responsible for dispatching to the correct function based on +// CPU capabilities; Sum256x8 will crash on non-AVX-512 hardware. +// +// # Platform availability +// +// Sum256x4 and Sum256x8 are only compiled on linux/amd64 (with the !purego +// build tag). On every other platform (darwin, windows, linux/arm64, or any +// build configured with -tags=purego) the symbols do not exist; callers must +// route around them using the HasSIMD / BatchWidth helpers below, which are +// available on all platforms and return false / 0 where no SIMD hasher is +// linked in. +// +// # Equal-length (or nil) input constraint +// +// All non-nil inputs passed to Sum256x4 / Sum256x8 in the same call MUST have +// the same length. This is an intrinsic limit of the batched XKCP primitive: +// KeccakP1600timesN_PermuteAll_24rounds advances the sponge state of every +// lane in lockstep, so any lane that has already absorbed its final padding +// block gets extra, unwanted permutations every time a longer lane absorbs +// another block — silently corrupting its output. +// +// Nil (or zero-length) lanes are allowed for partial batches: they are +// treated as no-op fillers so the caller can populate N_real < N lanes with +// real data and ignore the remaining N - N_real output digests. The digests +// produced for nil lanes are not meaningful and must not be consumed. Mixing +// distinct non-zero lengths within one call is unsupported and will produce +// wrong digests for the shorter lanes, even when every individual length +// would work on its own in an all-same-length call. 
+package keccak + +import "encoding/hex" + +// Hash256 represents a 32-byte Keccak-256 hash +type Hash256 [32]byte + +// HexString returns the hash as a hexadecimal string +func (h Hash256) HexString() string { + return hex.EncodeToString(h[:]) +} + +// HasAVX512 reports whether the CPU supports AVX-512 (F + VL) and the +// AVX-512 code path is available. +func HasAVX512() bool { + return hasAVX512 +} + +// HasSIMD reports whether any SIMD-accelerated Keccak path is available +// (AVX2 or AVX-512). +func HasSIMD() bool { + return hasAVX2 || hasAVX512 +} + +// BatchWidth returns the SIMD batch width: 8 for AVX-512, 4 for AVX2, or 0 +// if no SIMD acceleration is available. +func BatchWidth() int { + if hasAVX512 { + return 8 + } + if hasAVX2 { + return 4 + } + return 0 +} diff --git a/pkg/keccak/keccak_amd64.go b/pkg/keccak/keccak_amd64.go new file mode 100644 index 00000000000..3baac55736d --- /dev/null +++ b/pkg/keccak/keccak_amd64.go @@ -0,0 +1,13 @@ +// Copyright 2026 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package keccak + +//go:noescape +func keccak256x4(inputs *[4][]byte, outputs *[4]Hash256) + +//go:noescape +func keccak256x8(inputs *[8][]byte, outputs *[8]Hash256) diff --git a/pkg/keccak/keccak_cpu.go b/pkg/keccak/keccak_cpu.go new file mode 100644 index 00000000000..5d0d7763f22 --- /dev/null +++ b/pkg/keccak/keccak_cpu.go @@ -0,0 +1,14 @@ +// Copyright 2026 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build linux && amd64 && !purego + +package keccak + +import cpuid "github.com/klauspost/cpuid/v2" + +var ( + hasAVX2 = cpuid.CPU.Supports(cpuid.AVX2) + hasAVX512 = cpuid.CPU.Supports(cpuid.AVX512F, cpuid.AVX512VL) +) diff --git a/pkg/keccak/keccak_cpu_other.go b/pkg/keccak/keccak_cpu_other.go new file mode 100644 index 00000000000..1258c36ce17 --- /dev/null +++ b/pkg/keccak/keccak_cpu_other.go @@ -0,0 +1,14 @@ +// Copyright 2026 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !linux || !amd64 || purego + +package keccak + +// No SIMD Keccak implementations are available on this platform; Sum256x4/ +// Sum256x8 are not compiled here, so callers must take the scalar path. +var ( + hasAVX2 = false + hasAVX512 = false +) diff --git a/pkg/keccak/keccak_simd.go b/pkg/keccak/keccak_simd.go new file mode 100644 index 00000000000..126c53cb909 --- /dev/null +++ b/pkg/keccak/keccak_simd.go @@ -0,0 +1,42 @@ +// Copyright 2026 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package keccak + +// Sum256x4 computes 4 Keccak-256 hashes in parallel using AVX2. +// +// All non-nil inputs MUST have the same length; nil (or zero-length) lanes +// are allowed as partial-batch fillers whose output must be ignored. See the +// package doc for the rationale. +// +// Only compiled on linux/amd64 (the only platform where the XKCP .syso is +// linkable). Call sites that may run on other platforms must be gated on the +// same build tag or on keccak.HasSIMD() at runtime. +func Sum256x4(inputs [4][]byte) [4]Hash256 { + var outputs [4]Hash256 + var inputsCopy [4][]byte + copy(inputsCopy[:], inputs[:]) + keccak256x4(&inputsCopy, &outputs) + return outputs +} + +// Sum256x8 computes 8 Keccak-256 hashes in parallel using AVX-512.
+// Must only be called on AVX-512-capable hardware. +// +// All non-nil inputs MUST have the same length; nil (or zero-length) lanes +// are allowed as partial-batch fillers whose output must be ignored. See the +// package doc for the rationale. +// +// Only compiled on linux/amd64 (the only platform where the XKCP .syso is +// linkable). Call sites that may run on other platforms must be gated on the +// same build tag or on keccak.HasAVX512() at runtime. +func Sum256x8(inputs [8][]byte) [8]Hash256 { + var outputs [8]Hash256 + var inputsCopy [8][]byte + copy(inputsCopy[:], inputs[:]) + keccak256x8(&inputsCopy, &outputs) + return outputs +} diff --git a/pkg/keccak/keccak_test.go b/pkg/keccak/keccak_test.go new file mode 100644 index 00000000000..1ac1751e21b --- /dev/null +++ b/pkg/keccak/keccak_test.go @@ -0,0 +1,144 @@ +// Copyright 2026 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux && amd64 && !purego + +package keccak + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "golang.org/x/crypto/sha3" +) + +// referenceHash returns the legacy Keccak-256 digest using golang.org/x/crypto, +// which is the same function the BMT hasher uses in its goroutine path. +func referenceHash(data []byte) []byte { + h := sha3.NewLegacyKeccak256() + h.Write(data) + return h.Sum(nil) +} + +// testInputLengths covers a wide spread: sub-rate (<136), one-byte-before the +// rate (135 — where Keccak padding forces the 0x81 collision), exact rate, +// just over, and multi-block sizes out to 4096 bytes. The BMT caller only +// uses 64/96-byte inputs; the larger sizes exercise the lockstep multi-block +// absorption in the C wrapper. 
+var testInputLengths = []int{0, 1, 31, 32, 63, 64, 65, 95, 96, 97, 128, 134, 135, 136, 137, 200, 272, 1024, 4096} + +func TestSum256x4(t *testing.T) { + if !HasSIMD() { + t.Skip("AVX2 not available on this CPU") + } + + for _, length := range testInputLengths { + t.Run(fmt.Sprintf("len_%d", length), func(t *testing.T) { + var inputs [4][]byte + var expected [4][]byte + for i := range inputs { + buf := make([]byte, length) + _, _ = rand.Read(buf) + inputs[i] = buf + expected[i] = referenceHash(buf) + } + + got := Sum256x4(inputs) + for i := range got { + if !bytes.Equal(got[i][:], expected[i]) { + t.Errorf("lane %d mismatch:\n got %x\n want %x", i, got[i], expected[i]) + } + } + }) + } +} + +// TestSum256x4_PartialBatch mirrors the BMT caller's usage when the last batch +// is partial: some lanes carry real data, the rest are nil. The wrapper +// produces a (meaningless) digest for the nil lanes and correct digests for +// the real lanes — only the latter are asserted. +func TestSum256x4_PartialBatch(t *testing.T) { + if !HasSIMD() { + t.Skip("AVX2 not available on this CPU") + } + + for _, realLen := range []int{64, 96, 128} { + for realLanes := 1; realLanes < 4; realLanes++ { + t.Run(fmt.Sprintf("real_%d_of_4_len_%d", realLanes, realLen), func(t *testing.T) { + var inputs [4][]byte + var expected [4][]byte + for i := range realLanes { + buf := make([]byte, realLen) + _, _ = rand.Read(buf) + inputs[i] = buf + expected[i] = referenceHash(buf) + } + // inputs[realLanes..3] remain nil + + got := Sum256x4(inputs) + for i := range realLanes { + if !bytes.Equal(got[i][:], expected[i]) { + t.Errorf("real lane %d mismatch:\n got %x\n want %x", i, got[i], expected[i]) + } + } + }) + } + } +} + +func TestSum256x8(t *testing.T) { + if !HasAVX512() { + t.Skip("AVX-512 not available on this CPU") + } + + for _, length := range testInputLengths { + t.Run(fmt.Sprintf("len_%d", length), func(t *testing.T) { + var inputs [8][]byte + var expected [8][]byte + for i := range inputs { 
+ buf := make([]byte, length) + _, _ = rand.Read(buf) + inputs[i] = buf + expected[i] = referenceHash(buf) + } + + got := Sum256x8(inputs) + for i := range got { + if !bytes.Equal(got[i][:], expected[i]) { + t.Errorf("lane %d mismatch:\n got %x\n want %x", i, got[i], expected[i]) + } + } + }) + } +} + +func TestSum256x8_PartialBatch(t *testing.T) { + if !HasAVX512() { + t.Skip("AVX-512 not available on this CPU") + } + + for _, realLen := range []int{64, 96, 128} { + for realLanes := 1; realLanes < 8; realLanes++ { + t.Run(fmt.Sprintf("real_%d_of_8_len_%d", realLanes, realLen), func(t *testing.T) { + var inputs [8][]byte + var expected [8][]byte + for i := range realLanes { + buf := make([]byte, realLen) + _, _ = rand.Read(buf) + inputs[i] = buf + expected[i] = referenceHash(buf) + } + + got := Sum256x8(inputs) + for i := range realLanes { + if !bytes.Equal(got[i][:], expected[i]) { + t.Errorf("real lane %d mismatch:\n got %x\n want %x", i, got[i], expected[i]) + } + } + }) + } + } +} diff --git a/pkg/keccak/keccak_times4_linux_amd64.s b/pkg/keccak/keccak_times4_linux_amd64.s new file mode 100644 index 00000000000..dc06dc8ec25 --- /dev/null +++ b/pkg/keccak/keccak_times4_linux_amd64.s @@ -0,0 +1,10 @@ +//go:build amd64 && !purego + +#include "textflag.h" + +// func keccak256x4(inputs *[4][]byte, outputs *[4]Hash256) +TEXT ·keccak256x4(SB), $8192-16 + MOVQ inputs+0(FP), DI + MOVQ outputs+8(FP), SI + CALL go_keccak256x4(SB) + RET diff --git a/pkg/keccak/keccak_times4_linux_amd64.syso b/pkg/keccak/keccak_times4_linux_amd64.syso new file mode 100644 index 00000000000..5443a23f286 Binary files /dev/null and b/pkg/keccak/keccak_times4_linux_amd64.syso differ diff --git a/pkg/keccak/keccak_times8_linux_amd64.s b/pkg/keccak/keccak_times8_linux_amd64.s new file mode 100644 index 00000000000..81647acb01b --- /dev/null +++ b/pkg/keccak/keccak_times8_linux_amd64.s @@ -0,0 +1,12 @@ +//go:build amd64 && !purego + +#include "textflag.h" + +// func keccak256x8(inputs *[8][]byte, 
outputs *[8]Hash256) +// Frame size 16384: AVX-512 state is larger (25 x 64 bytes = 1600 bytes) and +// the permutation uses more stack. Generous headroom provided. +TEXT ·keccak256x8(SB), $16384-16 + MOVQ inputs+0(FP), DI + MOVQ outputs+8(FP), SI + CALL go_keccak256x8(SB) + RET diff --git a/pkg/keccak/keccak_times8_linux_amd64.syso b/pkg/keccak/keccak_times8_linux_amd64.syso new file mode 100644 index 00000000000..2f9ce34e549 Binary files /dev/null and b/pkg/keccak/keccak_times8_linux_amd64.syso differ diff --git a/pkg/storageincentives/main_test.go b/pkg/storageincentives/main_test.go index b302a1833fb..9b6e13090a0 100644 --- a/pkg/storageincentives/main_test.go +++ b/pkg/storageincentives/main_test.go @@ -7,9 +7,11 @@ package storageincentives_test import ( "testing" + "github.com/ethersphere/bee/v2/pkg/bmt" "go.uber.org/goleak" ) func TestMain(m *testing.M) { + bmt.SetSIMDOptIn(true) goleak.VerifyTestMain(m) } diff --git a/pkg/storageincentives/proof.go b/pkg/storageincentives/proof.go index 86dd115c67e..a95987b1d2f 100644 --- a/pkg/storageincentives/proof.go +++ b/pkg/storageincentives/proof.go @@ -7,7 +7,6 @@ package storageincentives import ( "errors" "fmt" - "hash" "math/big" "github.com/ethersphere/bee/v2/pkg/bmt" @@ -55,13 +54,10 @@ func makeInclusionProofs( require2++ } - prefixHasherFactory := func() hash.Hash { - return swarm.NewPrefixHasher(anchor1) - } - prefixHasherPool := bmt.NewPool(bmt.NewConf(prefixHasherFactory, swarm.BmtBranches, 8)) + prefixHasherPool := bmt.NewProverPool(bmt.NewConfWithPrefix(anchor1, swarm.BmtBranches, 8)) // Sample chunk proofs - rccontent := bmt.Prover{Hasher: bmtpool.Get()} + rccontent := bmtpool.GetProver() rccontent.SetHeaderInt64(swarm.HashSize * storer.SampleSize * 2) rsc, err := sampleChunk(reserveSampleItems) if err != nil { @@ -79,12 +75,12 @@ func makeInclusionProofs( proof1p1 := rccontent.Proof(int(require1) * 2) proof2p1 := rccontent.Proof(int(require2) * 2) proofLastp1 := rccontent.Proof(require3 * 2) - 
bmtpool.Put(rccontent.Hasher) + bmtpool.PutProver(rccontent) // Witness1 proofs segmentIndex := int(new(big.Int).Mod(new(big.Int).SetBytes(anchor2), big.NewInt(int64(128))).Uint64()) // OG chunk proof - chunk1Content := bmt.Prover{Hasher: bmtpool.Get()} + chunk1Content := bmtpool.GetProver() chunk1Offset := spanOffset(reserveSampleItems[require1]) chunk1Content.SetHeader(reserveSampleItems[require1].ChunkData[chunk1Offset : chunk1Offset+swarm.SpanSize]) chunk1ContentPayload := reserveSampleItems[require1].ChunkData[chunk1Offset+swarm.SpanSize:] @@ -98,7 +94,7 @@ func makeInclusionProofs( } proof1p2 := chunk1Content.Proof(segmentIndex) // TR chunk proof - chunk1TrContent := bmt.Prover{Hasher: prefixHasherPool.Get()} + chunk1TrContent := prefixHasherPool.GetProver() chunk1TrContent.SetHeader(reserveSampleItems[require1].ChunkData[chunk1Offset : chunk1Offset+swarm.SpanSize]) _, err = chunk1TrContent.Write(chunk1ContentPayload) if err != nil { @@ -110,13 +106,13 @@ func makeInclusionProofs( } proof1p3 := chunk1TrContent.Proof(segmentIndex) // cleanup - bmtpool.Put(chunk1Content.Hasher) - prefixHasherPool.Put(chunk1TrContent.Hasher) + bmtpool.PutProver(chunk1Content) + prefixHasherPool.PutProver(chunk1TrContent) // Witness2 proofs // OG Chunk proof chunk2Offset := spanOffset(reserveSampleItems[require2]) - chunk2Content := bmt.Prover{Hasher: bmtpool.Get()} + chunk2Content := bmtpool.GetProver() chunk2ContentPayload := reserveSampleItems[require2].ChunkData[chunk2Offset+swarm.SpanSize:] chunk2Content.SetHeader(reserveSampleItems[require2].ChunkData[chunk2Offset : chunk2Offset+swarm.SpanSize]) _, err = chunk2Content.Write(chunk2ContentPayload) @@ -129,7 +125,7 @@ func makeInclusionProofs( } proof2p2 := chunk2Content.Proof(segmentIndex) // TR Chunk proof - chunk2TrContent := bmt.Prover{Hasher: prefixHasherPool.Get()} + chunk2TrContent := prefixHasherPool.GetProver() chunk2TrContent.SetHeader(reserveSampleItems[require2].ChunkData[chunk2Offset : 
chunk2Offset+swarm.SpanSize]) _, err = chunk2TrContent.Write(chunk2ContentPayload) if err != nil { @@ -141,13 +137,13 @@ func makeInclusionProofs( } proof2p3 := chunk2TrContent.Proof(segmentIndex) // cleanup - bmtpool.Put(chunk2Content.Hasher) - prefixHasherPool.Put(chunk2TrContent.Hasher) + bmtpool.PutProver(chunk2Content) + prefixHasherPool.PutProver(chunk2TrContent) // Witness3 proofs // OG Chunk proof chunkLastOffset := spanOffset(reserveSampleItems[require3]) - chunkLastContent := bmt.Prover{Hasher: bmtpool.Get()} + chunkLastContent := bmtpool.GetProver() chunkLastContent.SetHeader(reserveSampleItems[require3].ChunkData[chunkLastOffset : chunkLastOffset+swarm.SpanSize]) chunkLastContentPayload := reserveSampleItems[require3].ChunkData[chunkLastOffset+swarm.SpanSize:] _, err = chunkLastContent.Write(chunkLastContentPayload) @@ -160,7 +156,7 @@ func makeInclusionProofs( } proofLastp2 := chunkLastContent.Proof(segmentIndex) // TR Chunk Proof - chunkLastTrContent := bmt.Prover{Hasher: prefixHasherPool.Get()} + chunkLastTrContent := prefixHasherPool.GetProver() chunkLastTrContent.SetHeader(reserveSampleItems[require3].ChunkData[chunkLastOffset : chunkLastOffset+swarm.SpanSize]) _, err = chunkLastTrContent.Write(chunkLastContentPayload) if err != nil { @@ -172,8 +168,8 @@ func makeInclusionProofs( } proofLastp3 := chunkLastTrContent.Proof(segmentIndex) // cleanup - bmtpool.Put(chunkLastContent.Hasher) - prefixHasherPool.Put(chunkLastTrContent.Hasher) + bmtpool.PutProver(chunkLastContent) + prefixHasherPool.PutProver(chunkLastTrContent) // map to output and add SOC related data if it is necessary A, err := redistribution.NewChunkInclusionProof(proof1p1, proof1p2, proof1p3, reserveSampleItems[require1]) diff --git a/pkg/storageincentives/soc_mine_test.go b/pkg/storageincentives/soc_mine_test.go index 0265a9a21f7..6193ede8e22 100644 --- a/pkg/storageincentives/soc_mine_test.go +++ b/pkg/storageincentives/soc_mine_test.go @@ -9,12 +9,10 @@ import ( "encoding/binary" 
"encoding/hex" "fmt" - "hash" "math/big" "os" "sync" "testing" - "testing/synctest" "github.com/ethersphere/bee/v2/pkg/bmt" "github.com/ethersphere/bee/v2/pkg/cac" @@ -33,62 +31,58 @@ import ( // to generate uploads using the input // cat socs.txt | tail 19 | head 16 | perl -pne 's/([a-f0-9]+)\t([a-f0-9]+)\t([a-f0-9]+)\t([a-f0-9]+)/echo -n $4 | xxd -r -p | curl -X POST \"http:\/\/localhost:1633\/soc\/$1\/$2?sig=$3\" -H \"accept: application\/json, text\/plain, \/\" -H \"content-type: application\/octet-stream\" -H \"swarm-postage-batch-id: 14b26beca257e763609143c6b04c2c487f01a051798c535c2f542ce75a97c05f\" --data-binary \@-/' func TestSocMine(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - // the anchor used in neighbourhood selection and reserve salt for sampling - prefix, err := hex.DecodeString("3617319a054d772f909f7c479a2cebe5066e836a939412e32403c99029b92eff") - if err != nil { - t.Fatal(err) - } - // the transformed address hasher factory function - prefixhasher := func() hash.Hash { return swarm.NewPrefixHasher(prefix) } - // Create a pool for efficient hasher reuse - trHasherPool := bmt.NewPool(bmt.NewConf(prefixhasher, swarm.BmtBranches, 8)) - // the bignum cast of the maximum sample value (upper bound on transformed addresses as a 256-bit article) - // this constant is for a minimum reserve size of 2 million chunks with sample size of 16 - // = 1.284401 * 10^71 = 1284401 + 66 0-s - mstring := "1284401" - for range 66 { - mstring = mstring + "0" - } - n, ok := new(big.Int).SetString(mstring, 10) - if !ok { - t.Fatalf("SetString: error setting to '%s'", mstring) - } - // the filter function on the SOC address - // meant to make sure we pass check for proof of retrievability for - // a node of overlay 0x65xxx with a reserve depth of 1, i.e., - // SOC address must start with zero bit - filterSOCAddr := func(a swarm.Address) bool { - return a.Bytes()[0]&0x80 != 0x00 - } - // the filter function on the transformed address using the density estimation 
constant - filterTrAddr := func(a swarm.Address) (bool, error) { - m := new(big.Int).SetBytes(a.Bytes()) - return m.Cmp(n) < 0, nil - } - // setup the signer with a private key from a fixture - data, err := hex.DecodeString("634fb5a872396d9693e5c9f9d7233cfa93f395c093371017ff44aa9ae6564cdd") - if err != nil { - t.Fatal(err) - } - privKey, err := crypto.DecodeSecp256k1PrivateKey(data) - if err != nil { - t.Fatal(err) - } - signer := crypto.NewDefaultSigner(privKey) + // the anchor used in neighbourhood selection and reserve salt for sampling + prefix, err := hex.DecodeString("3617319a054d772f909f7c479a2cebe5066e836a939412e32403c99029b92eff") + if err != nil { + t.Fatal(err) + } + // Create a pool for efficient hasher reuse + trHasherPool := bmt.NewPool(bmt.NewConfWithPrefix(prefix, swarm.BmtBranches, 8)) + // the bignum cast of the maximum sample value (upper bound on transformed addresses as a 256-bit article) + // this constant is for a minimum reserve size of 2 million chunks with sample size of 16 + // = 1.284401 * 10^71 = 1284401 + 66 0-s + mstring := "1284401" + for range 66 { + mstring = mstring + "0" + } + n, ok := new(big.Int).SetString(mstring, 10) + if !ok { + t.Fatalf("SetString: error setting to '%s'", mstring) + } + // the filter function on the SOC address + // meant to make sure we pass check for proof of retrievability for + // a node of overlay 0x65xxx with a reserve depth of 1, i.e., + // SOC address must start with zero bit + filterSOCAddr := func(a swarm.Address) bool { + return a.Bytes()[0]&0x80 != 0x00 + } + // the filter function on the transformed address using the density estimation constant + filterTrAddr := func(a swarm.Address) (bool, error) { + m := new(big.Int).SetBytes(a.Bytes()) + return m.Cmp(n) < 0, nil + } + // setup the signer with a private key from a fixture + data, err := hex.DecodeString("634fb5a872396d9693e5c9f9d7233cfa93f395c093371017ff44aa9ae6564cdd") + if err != nil { + t.Fatal(err) + } + privKey, err := 
crypto.DecodeSecp256k1PrivateKey(data) + if err != nil { + t.Fatal(err) + } + signer := crypto.NewDefaultSigner(privKey) - sampleSize := 16 - // for sanity check: given a filterSOCAddr requiring a 0 leading bit (chance of 1/2) - // we expect an overall rough 4 million chunks to be mined to create this sample - // for 8 workers that is half a million round on average per worker - err = makeChunks(t, signer, sampleSize, filterSOCAddr, filterTrAddr, trHasherPool) - if err != nil { - t.Fatal(err) - } - }) + sampleSize := 16 + // for sanity check: given a filterSOCAddr requiring a 0 leading bit (chance of 1/2) + // we expect an overall rough 4 million chunks to be mined to create this sample + // for 8 workers that is half a million round on average per worker + err = makeChunks(t, signer, sampleSize, filterSOCAddr, filterTrAddr, trHasherPool) + if err != nil { + t.Fatal(err) + } } -func makeChunks(t *testing.T, signer crypto.Signer, sampleSize int, filterSOCAddr func(swarm.Address) bool, filterTrAddr func(swarm.Address) (bool, error), trHasherPool *bmt.Pool) error { +func makeChunks(t *testing.T, signer crypto.Signer, sampleSize int, filterSOCAddr func(swarm.Address) bool, filterTrAddr func(swarm.Address) (bool, error), trHasherPool bmt.Pool) error { t.Helper() // set owner address from signer diff --git a/pkg/storer/sample.go b/pkg/storer/sample.go index 869da215e7d..7e64a09e392 100644 --- a/pkg/storer/sample.go +++ b/pkg/storer/sample.go @@ -9,7 +9,6 @@ import ( "context" "encoding/binary" "fmt" - "hash" "math/big" "runtime" "sort" @@ -121,16 +120,12 @@ func (db *DB) ReserveSample( // Phase 2: Get the chunk data and calculate transformed hash sampleItemChan := make(chan SampleItem, 3*workers) - prefixHasherFactory := func() hash.Hash { - return swarm.NewPrefixHasher(anchor) - } - db.logger.Debug("reserve sampler workers", "count", workers) for range workers { g.Go(func() error { wstat := SampleStats{} - hasher := bmt.NewHasher(prefixHasherFactory) + hasher := 
bmt.NewPrefixHasher(anchor) defer func() { addStats(wstat) }() @@ -299,7 +294,7 @@ func (db *DB) batchesBelowValue(until *big.Int) (map[string]struct{}, error) { return res, err } -func transformedAddress(hasher *bmt.Hasher, chunk swarm.Chunk, chType swarm.ChunkType) (swarm.Address, error) { +func transformedAddress(hasher bmt.Hasher, chunk swarm.Chunk, chType swarm.ChunkType) (swarm.Address, error) { switch chType { case swarm.ChunkTypeContentAddressed: return transformedAddressCAC(hasher, chunk) @@ -310,7 +305,7 @@ func transformedAddress(hasher *bmt.Hasher, chunk swarm.Chunk, chType swarm.Chun } } -func transformedAddressCAC(hasher *bmt.Hasher, chunk swarm.Chunk) (swarm.Address, error) { +func transformedAddressCAC(hasher bmt.Hasher, chunk swarm.Chunk) (swarm.Address, error) { hasher.Reset() hasher.SetHeader(chunk.Data()[:bmt.SpanSize]) @@ -327,7 +322,7 @@ func transformedAddressCAC(hasher *bmt.Hasher, chunk swarm.Chunk) (swarm.Address return swarm.NewAddress(taddr), nil } -func transformedAddressSOC(hasher *bmt.Hasher, socChunk swarm.Chunk) (swarm.Address, error) { +func transformedAddressSOC(hasher bmt.Hasher, socChunk swarm.Chunk) (swarm.Address, error) { // Calculate transformed address from wrapped chunk cacChunk, err := soc.UnwrapCAC(socChunk) if err != nil { @@ -407,12 +402,9 @@ func RandSample(t *testing.T, anchor []byte) Sample { // MakeSampleUsingChunks returns Sample constructed using supplied chunks. func MakeSampleUsingChunks(chunks []swarm.Chunk, anchor []byte) (Sample, error) { - prefixHasherFactory := func() hash.Hash { - return swarm.NewPrefixHasher(anchor) - } items := make([]SampleItem, len(chunks)) for i, ch := range chunks { - tr, err := transformedAddress(bmt.NewHasher(prefixHasherFactory), ch, getChunkType(ch)) + tr, err := transformedAddress(bmt.NewPrefixHasher(anchor), ch, getChunkType(ch)) if err != nil { return Sample{}, err }