contractcourt: update encode/decode for taproot aux data

When we read/write the aux data, we need to make sure we always set the
new fields for aux HTLCs.
This commit is contained in:
Olaoluwa Osuntokun
2024-10-15 19:19:22 -07:00
parent 9b8adf5f5c
commit 7ef2683586

View File

@@ -1571,6 +1571,7 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
})
}
htlcBlobs := newAuxHtlcBlobs()
for _, htlc := range c.HtlcResolutions.IncomingHTLCs {
htlc := htlc
@@ -1581,8 +1582,9 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
continue
}
var resID resolverID
if htlc.SignedSuccessTx != nil {
resID := newResolverID(
resID = newResolverID(
htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint,
)
//nolint:lll
@@ -1598,10 +1600,14 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
}
} else {
resID := newResolverID(htlc.ClaimOutpoint)
resID = newResolverID(htlc.ClaimOutpoint)
//nolint:lll
tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = ctrlBlock
}
htlc.ResolutionBlob.WhenSome(func(b []byte) {
htlcBlobs[resID] = b
})
}
for _, htlc := range c.HtlcResolutions.OutgoingHTLCs {
htlc := htlc
@@ -1613,8 +1619,9 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
continue
}
var resID resolverID
if htlc.SignedTimeoutTx != nil {
resID := newResolverID(
resID = newResolverID(
htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint,
)
//nolint:lll
@@ -1632,10 +1639,14 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
}
} else {
resID := newResolverID(htlc.ClaimOutpoint)
resID = newResolverID(htlc.ClaimOutpoint)
//nolint:lll
tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock
}
htlc.ResolutionBlob.WhenSome(func(b []byte) {
htlcBlobs[resID] = b
})
}
if c.AnchorResolution != nil {
@@ -1643,6 +1654,12 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
tapCase.TapTweaks.Val.AnchorTweak = anchorSignDesc.TapTweak
}
if len(htlcBlobs) != 0 {
tapCase.HtlcBlobs = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4](htlcBlobs),
)
}
return tapCase.Encode(w)
}
@@ -1661,6 +1678,8 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
})
}
htlcBlobs := tapCase.HtlcBlobs.ValOpt().UnwrapOr(newAuxHtlcBlobs())
for i := range c.HtlcResolutions.IncomingHTLCs {
htlc := c.HtlcResolutions.IncomingHTLCs[i]
@@ -1687,7 +1706,12 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
htlc.SweepSignDesc.ControlBlock = ctrlBlock
}
if htlcBlob, ok := htlcBlobs[resID]; ok {
htlc.ResolutionBlob = fn.Some(htlcBlob)
}
c.HtlcResolutions.IncomingHTLCs[i] = htlc
}
for i := range c.HtlcResolutions.OutgoingHTLCs {
htlc := c.HtlcResolutions.OutgoingHTLCs[i]
@@ -1715,6 +1739,10 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
htlc.SweepSignDesc.ControlBlock = ctrlBlock
}
if htlcBlob, ok := htlcBlobs[resID]; ok {
htlc.ResolutionBlob = fn.Some(htlcBlob)
}
c.HtlcResolutions.OutgoingHTLCs[i] = htlc
}