package hash

import (
	"crypto/sha256"
	"errors"
)

// ExpandMsgXmd expands msg into lenInBytes pseudo-random bytes using
// expand_message_xmd from RFC 9380, with SHA-256 as the hash function H.
//
// dst is the domain separation tag and must be at most 255 bytes long.
// lenInBytes must satisfy 0 <= lenInBytes <= 255*32 (= 8160), so that
// ell = ceil(lenInBytes/32) <= 255. This also implies the RFC's
// len_in_bytes <= 65535 bound. An error is returned if either limit is violated.
//
// https://datatracker.ietf.org/doc/html/rfc9380#name-expand_message_xmd
// https://datatracker.ietf.org/doc/html/rfc9380#name-utility-functions (I2OSP/OS2IP)
func ExpandMsgXmd(msg, dst []byte, lenInBytes int) ([]byte, error) {
	h := sha256.New()
	bInBytes := h.Size()
	ell := (lenInBytes + bInBytes - 1) / bInBytes // ceil(len_in_bytes / b_in_bytes)
	// A negative lenInBytes yields ell <= 0. That would slip past the ell bound
	// and then panic in make, so it is rejected explicitly.
	if lenInBytes < 0 || ell > 255 {
		return nil, errors.New("invalid lenInBytes")
	}
	if len(dst) > 255 {
		return nil, errors.New("invalid domain size (>255 bytes)")
	}

	// DST_prime = DST ∥ I2OSP(len(DST), 1), built once and reused for every block.
	dstPrime := make([]byte, len(dst)+1)
	copy(dstPrime, dst)
	dstPrime[len(dst)] = uint8(len(dst))

	// hash.Hash.Write is documented to never return an error, so the results
	// of the Write calls below are intentionally not checked.

	// b₀ = H(Z_pad ∥ msg ∥ l_i_b_str ∥ I2OSP(0, 1) ∥ DST_prime)
	// where Z_pad = I2OSP(0, r_in_bytes) and l_i_b_str = I2OSP(len_in_bytes, 2).
	h.Write(make([]byte, h.BlockSize()))
	h.Write(msg)
	h.Write([]byte{uint8(lenInBytes >> 8), uint8(lenInBytes), 0})
	h.Write(dstPrime)
	b0 := h.Sum(nil)

	// b₁ = H(b₀ ∥ I2OSP(1, 1) ∥ DST_prime)
	h.Reset()
	h.Write(b0)
	h.Write([]byte{1})
	h.Write(dstPrime)
	bi := h.Sum(nil)

	res := make([]byte, lenInBytes)
	// copy stops at len(res), so outputs shorter than one digest
	// (lenInBytes < 32) are truncated instead of slicing past capacity.
	copy(res, bi)

	strxor := make([]byte, bInBytes) // scratch buffer reused across iterations
	for i := 2; i <= ell; i++ {
		// b_i = H(strxor(b₀, b_(i-1)) ∥ I2OSP(i, 1) ∥ DST_prime)
		for j := range strxor {
			strxor[j] = b0[j] ^ bi[j]
		}
		h.Reset()
		h.Write(strxor)
		h.Write([]byte{uint8(i)})
		h.Write(dstPrime)
		// b_(i-1) has already been consumed into strxor, so its storage can be reused.
		bi = h.Sum(bi[:0])
		// Only part of the last block may be needed; copy truncates it to fit res.
		copy(res[bInBytes*(i-1):], bi)
	}
	return res, nil
}

// min returns the smaller of a and b.
//
// NOTE(review): ExpandMsgXmd no longer calls this helper. It is kept because
// other files in this package may rely on it. On Go 1.21+ it shadows the
// built-in min within this package. Remove it once no callers remain.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}