Files
mev-beta/vendor/github.com/crate-crypto/go-ipa/banderwagon/precomp.go

244 lines
7.2 KiB
Go

package banderwagon
import (
"context"
"fmt"
"runtime"
"github.com/crate-crypto/go-ipa/bandersnatch"
"github.com/crate-crypto/go-ipa/bandersnatch/fp"
"github.com/crate-crypto/go-ipa/bandersnatch/fr"
"github.com/crate-crypto/go-ipa/common/parallel"
"golang.org/x/sync/errgroup"
)
const (
	// supportedMSMLength is the number of points supported by the precomputed tables.
	supportedMSMLength = 256
	// window16vs8IndexLimit is the index of the first point that will use an 8-bit
	// window instead of a 16-bit window (points [0, limit) get 16-bit windows).
	window16vs8IndexLimit = 5
)
// MSMPrecomp is an engine to calculate 256-MSM on a fixed basis using precomputed tables.
// This precomputed tables design are biased to support an efficient MSM for Verkle Trees.
//
// Their design involves 16-bit windows for the first window16vs8IndexLimit points, and 8-bit
// windows for the rest. The motivation for this is that the first points are used to calculate
// tree keys, which clients heavily rely on compared to "longer" MSMs. This provides a significant
// boost to tree-key generation without exploding table sizes.
type MSMPrecomp struct {
	// precompPoints holds one precomputed table per basis point, indexed by the
	// point's position in the MSM.
	precompPoints [supportedMSMLength]PrecompPoint
}
// NewPrecompMSM creates a new MSMPrecomp for the given fixed basis, which must
// contain exactly supportedMSMLength points. It builds one precomputed table
// per point: 16-bit windows for the first window16vs8IndexLimit points (these
// dominate tree-key generation) and 8-bit windows for the rest.
func NewPrecompMSM(points []Element) (MSMPrecomp, error) {
	if len(points) != supportedMSMLength {
		return MSMPrecomp{}, fmt.Errorf("the number of points must be %d", supportedMSMLength)
	}

	var tables [supportedMSMLength]PrecompPoint
	for i, p := range points {
		// Wider windows for the leading points, narrower for the tail.
		ws := 8
		if i < window16vs8IndexLimit {
			ws = 16
		}
		pp, err := NewPrecompPoint(p, ws)
		if err != nil {
			return MSMPrecomp{}, fmt.Errorf("creating precomputed table for point: %s", err)
		}
		tables[i] = pp
	}

	return MSMPrecomp{
		precompPoints: tables,
	}, nil
}
// MSM calculates the 256-MSM of the given scalars on the fixed basis,
// skipping zero scalars so sparse inputs only pay for their non-zero entries.
// The scalars slice must not be longer than the precomputed basis.
func (msm *MSMPrecomp) MSM(scalars []fr.Element) Element {
	acc := bandersnatch.IdentityExt
	for idx := range scalars {
		if scalars[idx].IsZero() {
			continue
		}
		msm.precompPoints[idx].ScalarMul(scalars[idx], &acc)
	}
	return Element{inner: bandersnatch.PointProj{
		X: acc.X,
		Y: acc.Y,
		Z: acc.Z,
	}}
}
// PrecompPoint is a precomputed table for a single point.
type PrecompPoint struct {
	// windowSize is the number of scalar bits handled per window
	// (8 or 16 for the tables built by NewPrecompMSM).
	windowSize int
	// windows[i] stores the multiples {1, ..., 2^(windowSize-1)} of the base
	// point shifted by i*windowSize bits, in normalized extended coordinates.
	// Only half the window range is stored; ScalarMul recovers the other half
	// via negation and a carry.
	windows [][]bandersnatch.PointExtendedNormalized
}
// NewPrecompPoint creates a new PrecompPoint for the given point and window size.
func NewPrecompPoint(point Element, windowSize int) (PrecompPoint, error) {
if windowSize&(windowSize-1) != 0 {
return PrecompPoint{}, fmt.Errorf("window size must be a power of 2")
}
var specialWindow fr.Element
specialWindow.SetUint64(1 << windowSize)
res := PrecompPoint{
windowSize: windowSize,
windows: make([][]bandersnatch.PointExtendedNormalized, 256/windowSize),
}
windows := make([][]bandersnatch.PointExtended, 256/windowSize)
group, _ := errgroup.WithContext(context.Background())
group.SetLimit(runtime.NumCPU())
for i := 0; i < len(res.windows); i++ {
i := i
base := bandersnatch.PointExtendedFromProj(&point.inner)
group.Go(func() error {
windows[i] = make([]bandersnatch.PointExtended, 1<<(windowSize-1))
curr := base
for j := 0; j < len(windows[i]); j++ {
windows[i][j] = curr
curr.Add(&curr, &base)
}
res.windows[i] = batchToExtendedPointNormalized(windows[i])
return nil
})
point.ScalarMul(&point, &specialWindow)
}
_ = group.Wait()
return res, nil
}
// ScalarMul multiplies the point by the given scalar using the precomputed points.
// It applies a trick to push a carry between windows since our precomputed tables
// avoid storing point inverses.
//
// The scalar is taken out of Montgomery form and consumed windowSize bits at a
// time. Window values above 2^(windowSize-1) are rewritten as a subtraction of
// the complement plus a +1 carry into the next window, so the tables only need
// to store half the window range. The result is accumulated into res.
func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointExtended) {
	numWindowsInLimb := 64 / pp.windowSize
	// Leave Montgomery form so the limbs hold the literal scalar bits.
	// `scalar` is a value copy, so the caller's element is unaffected.
	scalar.FromMont()

	var carry uint64
	var pNeg bandersnatch.PointExtendedNormalized
	for l := 0; l < fr.Limbs; l++ {
		for w := 0; w < numWindowsInLimb; w++ {
			// Extract the w-th window of limb l and fold in the carry from
			// the previous window.
			windowValue := (scalar[l]>>(pp.windowSize*w))&((1<<pp.windowSize)-1) + carry
			if windowValue == 0 {
				// windowValue can only be 0 when the carry was 0 too, so
				// nothing needs adding or propagating.
				continue
			}
			carry = 0
			if windowValue > 1<<(pp.windowSize-1) {
				// v > half-range: rewrite v as 2^w - (2^w - v), i.e. subtract
				// the complement now and push +1 into the next window.
				windowValue = (1 << pp.windowSize) - windowValue
				if windowValue != 0 {
					pNeg.Neg(&pp.windows[l*numWindowsInLimb+w][windowValue-1])
					bandersnatch.ExtendedAddNormalized(res, res, &pNeg)
				}
				carry = 1
			} else {
				// v within the stored half-range: direct table lookup
				// (entry v-1 holds v*base for this window).
				bandersnatch.ExtendedAddNormalized(res, res, &pp.windows[l*numWindowsInLimb+w][windowValue-1])
			}
		}
	}
}
// batchProjToAffine converts a slice of points in projective coordinates to affine coordinates.
// This code was pulled from gnark-crypto which unfortunately doesn't have a variant for bandersnatch
// since it's a secondary curve in the generated code.
//
// It performs a single field inversion for the whole slice via the Montgomery
// batch inversion trick; points with Z == 0 are mapped to (0, 0).
func batchProjToAffine(points []bandersnatch.PointProj) []bandersnatch.PointAffine {
	result := make([]bandersnatch.PointAffine, len(points))
	zeroes := make([]bool, len(points))
	accumulator := fp.One()

	// batch invert all points[].Z coordinates with Montgomery batch inversion trick
	// (stores the running prefix products in result[i].X to avoid allocating a
	// slice of fp.Elements)
	for i := 0; i < len(points); i++ {
		if points[i].Z.IsZero() {
			// Z == 0: skip so the accumulator stays invertible.
			zeroes[i] = true
			continue
		}
		result[i].X = accumulator
		accumulator.Mul(&accumulator, &points[i].Z)
	}

	// One inversion for the entire batch.
	var accInverse fp.Element
	accInverse.Inverse(&accumulator)

	// Walk backwards peeling one Z factor off the inverse per step, leaving
	// Z_i^-1 in result[i].X. The reverse order is essential to the trick.
	for i := len(points) - 1; i >= 0; i-- {
		if zeroes[i] {
			// do nothing, (X=0, Y=0) is infinity point in affine
			continue
		}
		result[i].X.Mul(&result[i].X, &accInverse)
		accInverse.Mul(&accInverse, &points[i].Z)
	}

	// batch convert to affine: (X*Z^-1, Y*Z^-1).
	parallel.Execute(len(points), func(start, end int) {
		for i := start; i < end; i++ {
			if zeroes[i] {
				// do nothing, (X=0, Y=0) is infinity point in affine
				continue
			}
			// a holds Z_i^-1 (stashed in result[i].X by the passes above);
			// copy it out before overwriting result[i].X.
			a := result[i].X
			result[i].X.Mul(&points[i].X, &a)
			result[i].Y.Mul(&points[i].Y, &a)
		}
	})
	return result
}
// batchToExtendedPointNormalized converts a slice of extended points to
// normalized extended coordinates (Z scaled to 1, T = X*Y) using a single
// field inversion via the Montgomery batch inversion trick. Points with
// Z == 0 are left as the zero value.
func batchToExtendedPointNormalized(points []bandersnatch.PointExtended) []bandersnatch.PointExtendedNormalized {
	result := make([]bandersnatch.PointExtendedNormalized, len(points))
	zeroes := make([]bool, len(points))
	accumulator := fp.One()

	// batch invert all points[].Z coordinates with Montgomery batch inversion trick
	// (stores the running prefix products in result[i].X to avoid allocating a
	// slice of fp.Elements)
	for i := 0; i < len(points); i++ {
		if points[i].Z.IsZero() {
			// Z == 0: skip so the accumulator stays invertible.
			zeroes[i] = true
			continue
		}
		result[i].X = accumulator
		accumulator.Mul(&accumulator, &points[i].Z)
	}

	// One inversion for the entire batch.
	var accInverse fp.Element
	accInverse.Inverse(&accumulator)

	// Walk backwards peeling one Z factor off the inverse per step, leaving
	// Z_i^-1 in result[i].X. The reverse order is essential to the trick.
	for i := len(points) - 1; i >= 0; i-- {
		if zeroes[i] {
			// do nothing, (X=0, Y=0) is infinity point in affine
			continue
		}
		result[i].X.Mul(&result[i].X, &accInverse)
		accInverse.Mul(&accInverse, &points[i].Z)
	}

	// batch normalize: X*Z^-1, Y*Z^-1, and T = X*Y (valid since Z is now 1).
	parallel.Execute(len(points), func(start, end int) {
		for i := start; i < end; i++ {
			if zeroes[i] {
				// do nothing, (X=0, Y=0) is infinity point in affine
				continue
			}
			// a holds Z_i^-1 (stashed in result[i].X by the passes above);
			// copy it out before overwriting result[i].X.
			a := result[i].X
			result[i].X.Mul(&points[i].X, &a)
			result[i].Y.Mul(&points[i].Y, &a)
			result[i].T.Mul(&result[i].X, &result[i].Y)
		}
	})
	return result
}