Skip to content

Commit

Permalink
WIP wrapped hash
Browse files Browse the repository at this point in the history
  • Loading branch information
ivokub committed Oct 25, 2023
1 parent d1ab301 commit f58117b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
32 changes: 14 additions & 18 deletions std/recursion/wrapped_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,10 @@ func (h *shortNativeHash) Sum(b []byte) []byte {
h.ringBuf.Read(h.buf[1:])
h.wrapped.Write(h.buf)

// TODO: I'm cutting the hash on bit short to avoid edge cases:
// cut the hash a byte short to definitely fit
res := h.wrapped.Sum(nil)
fmt.Printf("sum %x\n", res)
nbBytes := (h.outSize + 7) / 8
res = res[len(res)-nbBytes+1:]
// mask := (1 << ((h.outSize - 1) % 8)) - 1
// res[0] &= byte(mask)
nbBytes := (h.outSize+7)/8 - 1
res = res[len(res)-nbBytes:]
return append(b, res...)
}

Expand All @@ -136,23 +133,23 @@ func (h *shortNativeHash) BlockSize() int {
}

type shortCircuitHash struct {
api frontend.API
bitLength int
wrapped stdhash.FieldHasher
buf []frontend.Variable
tmp []frontend.Variable
api frontend.API
outSize int
wrapped stdhash.FieldHasher
buf []frontend.Variable
tmp []frontend.Variable
}

func newHashFromParameter(api frontend.API, hf stdhash.FieldHasher, bitLength int) stdhash.FieldHasher {
tmp := make([]frontend.Variable, ((bitLength+7)/8)*8-8)
tmp := make([]frontend.Variable, ((api.Compiler().FieldBitLen()+7)/8)*8-8)
for i := range tmp {
tmp[i] = 0
}
return &shortCircuitHash{
api: api,
bitLength: bitLength,
wrapped: hf,
tmp: tmp,
api: api,
outSize: bitLength,
wrapped: hf,
tmp: tmp,
}
}

Expand Down Expand Up @@ -190,9 +187,8 @@ func (h *shortCircuitHash) Sum() frontend.Variable {
v := bits.FromBinary(h.api, h.tmp)
h.wrapped.Write(v)
res := h.wrapped.Sum()
h.api.Println(res)
resBts := bits.ToBinary(h.api, res)
res = bits.FromBinary(h.api, resBts[:len(h.tmp)])
res = bits.FromBinary(h.api, resBts[:((h.outSize+7)/8-1)*8])
return res
}

Expand Down
12 changes: 3 additions & 9 deletions std/recursion/wrapped_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ func (c *shortHashCircuit) Define(api frontend.API) error {

func TestShortHash(t *testing.T) {
outerCurves := []ecc.ID{
// tinyfield.Modulus(),
ecc.BN254,
ecc.BLS12_381,
ecc.BLS12_377,
Expand All @@ -51,7 +50,7 @@ func TestShortHash(t *testing.T) {
}

assert := test.NewAssert(t)
nbInputs := 100
nbInputs := 19
for _, outer := range outerCurves {
outer := outer
for _, inner := range innerCurves {
Expand All @@ -63,21 +62,16 @@ func TestShortHash(t *testing.T) {
witness := &shortHashCircuit{Input: make([]frontend.Variable, nbInputs), inner: inner}
buf := make([]byte, (outer.ScalarField().BitLen()+7)/8)
for i := range witness.Input {
// el, _ := new(big.Int).SetString("1231231230981238971241240982112382934728934798234981324798123981724198712467928497124987124", 10)
// el.Mod(el, outer)
el, err := rand.Int(rand.Reader, outer.ScalarField())
assert.NoError(err)
el.FillBytes(buf)
// fmt.Printf("input: %x\n", buf)
h.Write(buf)
witness.Input[i] = el
}
res := h.Sum(nil)
witness.Output = res
err = test.IsSolved(circuit, witness, outer.ScalarField())
assert.NoError(err)
// assert.CheckCircuit(circuit, test.WithCurves(outer), test.WithValidAssignment(witness), test.NoFuzzing(), test.NoSerializationChecks(), test.NoSolidityChecks())
}, inner.String())
assert.CheckCircuit(circuit, test.WithCurves(outer), test.WithValidAssignment(witness), test.NoFuzzing(), test.NoSerializationChecks(), test.NoSolidityChecks())
}, outer.String(), inner.String())
}
}
}

0 comments on commit f58117b

Please sign in to comment.