Files
mev-beta/vendor/github.com/ethereum/go-verkle/conversion.go

227 lines
6.7 KiB
Go

package verkle
import (
"bytes"
"context"
"fmt"
"runtime"
"sort"
"github.com/crate-crypto/go-ipa/banderwagon"
"golang.org/x/sync/errgroup"
)
// BatchNewLeafNodeData is a struct that contains the data needed to create a new leaf node.
// It is the per-node input consumed by BatchNewLeafNode.
type BatchNewLeafNodeData struct {
// Stem is the stem of the leaf node to create.
Stem Stem
// Values maps a position within the leaf to the value stored at that position.
// BatchNewLeafNode uses each key as an index into a NodeWidth-sized value slice;
// positions absent from the map are left nil.
Values map[byte][]byte
}
// BatchNewLeafNode creates one LeafNode per element of nodesValues. It optimizes
// LeafNode creation by batching expensive cryptography operations: the inputs are
// split into at most runtime.NumCPU() contiguous batches that are processed
// concurrently, and within each batch the c1/c2 points of all leaves are mapped to
// the scalar field with a single banderwagon.BatchMapToScalarField call.
//
// It returns the LeafNodes sorted by stem. If any batch fails, the first error
// reported by the errgroup is returned wrapped (errors.Is/As work on it) and the
// returned slice is nil.
func BatchNewLeafNode(nodesValues []BatchNewLeafNodeData) ([]LeafNode, error) {
	cfg := GetConfig()
	ret := make([]LeafNode, len(nodesValues))

	// Use at most one batch per node. This avoids starting goroutines for empty
	// batches when there are fewer nodes than CPUs.
	numBatches := runtime.NumCPU()
	if numBatches > len(nodesValues) {
		numBatches = len(nodesValues)
	}

	group, _ := errgroup.WithContext(context.Background())
	for b := 0; b < numBatches; b++ {
		// Split as evenly as possible: batch sizes differ by at most one, and the
		// last batch always ends at len(nodesValues).
		start := b * len(nodesValues) / numBatches
		end := (b + 1) * len(nodesValues) / numBatches
		// Per-iteration copies so each goroutine works on its own disjoint window.
		batchRet, batchValues := ret[start:end], nodesValues[start:end]

		group.Go(func() error {
			c1c2points := make([]*Point, 2*len(batchValues))
			c1c2frs := make([]*Fr, 2*len(batchValues))
			for j, nv := range batchValues {
				valsslice := make([][]byte, NodeWidth)
				for idx, val := range nv.Values {
					valsslice[idx] = val
				}

				leaf, err := NewLeafNode(nv.Stem, valsslice)
				if err != nil {
					return err
				}
				batchRet[j] = *leaf
				c1c2points[2*j], c1c2points[2*j+1] = batchRet[j].c1, batchRet[j].c2
				c1c2frs[2*j], c1c2frs[2*j+1] = new(Fr), new(Fr)
			}

			// One batched mapping for every c1/c2 point in this batch.
			if err := banderwagon.BatchMapToScalarField(c1c2frs, c1c2points); err != nil {
				return fmt.Errorf("mapping to scalar field: %w", err)
			}

			// The leaf commitment polynomial is [1, stem, c1, c2, 0, ...]. Only
			// entries 1..3 change per leaf, so a single buffer is reused.
			var poly [NodeWidth]Fr
			poly[0].SetUint64(1)
			for j, nv := range batchValues {
				if err := StemFromLEBytes(&poly[1], nv.Stem); err != nil {
					return err
				}
				poly[2] = *c1c2frs[2*j]
				poly[3] = *c1c2frs[2*j+1]

				// NOTE(review): 252 appears to be the number of trailing zero
				// entries (NodeWidth-4) — confirm against CommitToPoly.
				batchRet[j].commitment = cfg.CommitToPoly(poly[:], 252)
			}
			return nil
		})
	}
	if err := group.Wait(); err != nil {
		return nil, fmt.Errorf("creating leaf node: %w", err)
	}

	sort.Slice(ret, func(i, j int) bool {
		return bytes.Compare(ret[i].stem, ret[j].stem) < 0
	})
	return ret, nil
}
// firstDiffByteIdx reports the position of the first byte at which stem1 and
// stem2 disagree. The caller guarantees the stems differ within the first
// len(stem1) bytes (and that stem2 is at least that long); identical stems are
// a programming error and cause a panic.
func firstDiffByteIdx(stem1 []byte, stem2 []byte) int {
	for pos, b := range stem1 {
		if b == stem2[pos] {
			continue
		}
		return pos
	}
	panic("stems are equal")
}
// InsertMigratedLeaves inserts the given migrated leaf nodes into the tree rooted
// at n, using resolver to load any HashedNode met on the way down.
//
// leaves is sorted in place by stem. Leaves are then grouped by the first byte of
// their stem; each group lands under a different child of n, so the groups are
// inserted concurrently (at most runtime.NumCPU() at a time). The first error
// from any group is returned wrapped.
func (n *InternalNode) InsertMigratedLeaves(leaves []LeafNode, resolver NodeResolverFn) error {
	sort.Slice(leaves, func(i, j int) bool {
		return bytes.Compare(leaves[i].stem, leaves[j].stem) < 0
	})

	// First resolve and mark (cowChild) every child of n that will be updated, so
	// the parallel subtree updates below don't produce concurrent first-time
	// access to n.cowChild(...) for the same node.
	lastChildIdx := -1
	for i := range leaves {
		childIdx := int(leaves[i].stem[0])
		if childIdx == lastChildIdx {
			continue
		}
		lastChildIdx = childIdx

		if _, ok := n.children[childIdx].(HashedNode); ok {
			path := []byte{byte(childIdx)}
			serialized, err := resolver(path)
			if err != nil {
				return fmt.Errorf("resolving node path=%x: %w", path, err)
			}
			resolved, err := ParseNode(serialized, 1)
			if err != nil {
				return fmt.Errorf("parsing node %x: %w", serialized, err)
			}
			n.children[childIdx] = resolved
		}
		n.cowChild(byte(childIdx))
	}

	// Insert each run of leaves that share a first stem byte in its own goroutine.
	group, _ := errgroup.WithContext(context.Background())
	group.SetLimit(runtime.NumCPU())
	start := 0
	for end := 1; end <= len(leaves); end++ {
		if end < len(leaves) && leaves[end].stem[0] == leaves[start].stem[0] {
			continue
		}
		// Per-iteration variable so each goroutine gets its own sub-slice.
		subtree := leaves[start:end]
		group.Go(func() error {
			return n.insertMigratedLeavesSubtree(subtree, resolver)
		})
		start = end
	}
	if err := group.Wait(); err != nil {
		return fmt.Errorf("inserting migrated leaves: %w", err)
	}
	return nil
}
// insertMigratedLeavesSubtree inserts each leaf of leaves below n. For each leaf it
// walks down the internal nodes along the leaf's stem, resolving HashedNodes with
// resolver, until the slot for the leaf holds something other than an
// *InternalNode. Then:
//   - Empty: the leaf is placed there.
//   - *LeafNode with the same stem: only values that are nil in the existing leaf
//     are filled in from the migrated leaf (a present value is newer, so the
//     migrated one is stale).
//   - *LeafNode with a different stem: internal nodes are created down to the
//     depth where the two stems fork, and both leaves are placed there.
//
// Any other node type is an error.
func (n *InternalNode) insertMigratedLeavesSubtree(leaves []LeafNode, resolver NodeResolverFn) error { // skipcq: GO-R1005
	for i := range leaves {
		// ln is a fresh copy on every iteration; its address is stored in the tree.
		ln := leaves[i]
		parent := n

		// Look for the appropriate parent for the leaf node.
		for {
			slot := ln.stem[parent.depth]
			if _, ok := parent.children[slot].(HashedNode); ok {
				path := ln.stem[:parent.depth+1]
				serialized, err := resolver(path)
				if err != nil {
					return fmt.Errorf("resolving node path=%x: %w", path, err)
				}
				resolved, err := ParseNode(serialized, parent.depth+1)
				if err != nil {
					return fmt.Errorf("parsing node %x: %w", serialized, err)
				}
				parent.children[slot] = resolved
			}
			nextParent, ok := parent.children[slot].(*InternalNode)
			if !ok {
				break
			}
			parent.cowChild(slot)
			parent = nextParent
		}

		slot := ln.stem[parent.depth]
		switch node := parent.children[slot].(type) {
		case Empty:
			parent.cowChild(slot)
			parent.children[slot] = &ln
			ln.setDepth(parent.depth + 1)
		case *LeafNode:
			if bytes.Equal(node.stem, ln.stem) {
				// Copy only the values the existing leaf doesn't have; present
				// values are newer than the migrated ones.
				nonPresentValues := make([][]byte, NodeWidth)
				for j := range ln.values {
					if node.values[j] == nil {
						nonPresentValues[j] = ln.values[j]
					}
				}
				// Mark the child in the parent BEFORE mutating the leaf, as every
				// other mutation path in this function does. Without this, a parent
				// below the root never records that this leaf changed.
				// NOTE(review): relies on cowChild snapshotting the child's current
				// commitment — confirm against cowChild.
				parent.cowChild(slot)
				if err := node.updateMultipleLeaves(nonPresentValues); err != nil {
					return fmt.Errorf("updating leaves: %w", err)
				}
				continue
			}

			// Otherwise, create the missing internal nodes down to the fork point
			// of the two stems.
			idx := firstDiffByteIdx(node.stem, ln.stem)
			// Sanity check: the fork point can't be at or above the current depth.
			if byte(idx) <= parent.depth {
				return fmt.Errorf("unexpected fork point %d for nodes %x and %x", idx, node.stem, ln.stem)
			}
			for d := parent.depth + 1; d <= byte(idx); d++ {
				nextParent := newInternalNode(parent.depth + 1).(*InternalNode)
				parent.cowChild(ln.stem[parent.depth])
				parent.children[ln.stem[parent.depth]] = nextParent
				parent = nextParent
			}

			// Add the old and new leaf nodes to the deepest created parent.
			parent.cowChild(node.stem[parent.depth])
			parent.children[node.stem[parent.depth]] = node
			node.setDepth(parent.depth + 1)
			parent.cowChild(ln.stem[parent.depth])
			parent.children[ln.stem[parent.depth]] = &ln
			ln.setDepth(parent.depth + 1)
		default:
			return fmt.Errorf("unexpected node type %T", node)
		}
	}
	return nil
}