input: use lnutils.SyncMap to store musig2 sessions

This commit is contained in:
yyforyongyu 2023-11-25 04:46:55 +08:00 committed by Olaoluwa Osuntokun
parent 6744c64e62
commit 23e177514f

View File

@ -8,6 +8,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/multimutex" "github.com/lightningnetwork/lnd/multimutex"
) )
@ -38,7 +39,7 @@ type MusigSessionManager struct {
sessionMtx *multimutex.Mutex[MuSig2SessionID] sessionMtx *multimutex.Mutex[MuSig2SessionID]
musig2Sessions map[MuSig2SessionID]*MuSig2State musig2Sessions *lnutils.SyncMap[MuSig2SessionID, *MuSig2State]
} }
// NewMusigSessionManager creates a new musig manager given an abstract key // NewMusigSessionManager creates a new musig manager given an abstract key
@ -46,7 +47,9 @@ type MusigSessionManager struct {
func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager { func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager {
return &MusigSessionManager{ return &MusigSessionManager{
keyFetcher: keyFetcher, keyFetcher: keyFetcher,
musig2Sessions: make(map[MuSig2SessionID]*MuSig2State), musig2Sessions: &lnutils.SyncMap[
MuSig2SessionID, *MuSig2State,
]{},
sessionMtx: multimutex.NewMutex[MuSig2SessionID](), sessionMtx: multimutex.NewMutex[MuSig2SessionID](),
} }
} }
@ -134,9 +137,7 @@ func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
// //
// We'll use just all zeroes as the session ID for the mutex, as this // We'll use just all zeroes as the session ID for the mutex, as this
// is a "global" action. // is a "global" action.
m.sessionMtx.Lock(MuSig2SessionID{}) m.musig2Sessions.Store(session.SessionID, session)
m.musig2Sessions[session.SessionID] = session
m.sessionMtx.Unlock(MuSig2SessionID{})
return &session.MuSig2SessionInfo, nil return &session.MuSig2SessionInfo, nil
} }
@ -157,7 +158,7 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID,
m.sessionMtx.Lock(sessionID) m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID) defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID] session, ok := m.musig2Sessions.Load(sessionID)
if !ok { if !ok {
return nil, fmt.Errorf("session with ID %x not found", return nil, fmt.Errorf("session with ID %x not found",
sessionID[:]) sessionID[:])
@ -178,7 +179,7 @@ func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID,
// Clean up our local state if requested. // Clean up our local state if requested.
if cleanUp { if cleanUp {
delete(m.musig2Sessions, sessionID) m.musig2Sessions.Delete(sessionID)
} }
return partialSig, nil return partialSig, nil
@ -198,7 +199,7 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID,
m.sessionMtx.Lock(sessionID) m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID) defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID] session, ok := m.musig2Sessions.Load(sessionID)
if !ok { if !ok {
return nil, false, fmt.Errorf("session with ID %x not found", return nil, false, fmt.Errorf("session with ID %x not found",
sessionID[:]) sessionID[:])
@ -231,7 +232,7 @@ func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID,
// there is nothing more left to do. // there is nothing more left to do.
if session.HaveAllSigs { if session.HaveAllSigs {
finalSig = session.session.FinalSig() finalSig = session.session.FinalSig()
delete(m.musig2Sessions, sessionID) m.musig2Sessions.Delete(sessionID)
} }
return finalSig, session.HaveAllSigs, nil return finalSig, session.HaveAllSigs, nil
@ -245,12 +246,12 @@ func (m *MusigSessionManager) MuSig2Cleanup(sessionID MuSig2SessionID) error {
m.sessionMtx.Lock(sessionID) m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID) defer m.sessionMtx.Unlock(sessionID)
_, ok := m.musig2Sessions[sessionID] _, ok := m.musig2Sessions.Load(sessionID)
if !ok { if !ok {
return fmt.Errorf("session with ID %x not found", sessionID[:]) return fmt.Errorf("session with ID %x not found", sessionID[:])
} }
delete(m.musig2Sessions, sessionID) m.musig2Sessions.Delete(sessionID)
return nil return nil
} }
@ -267,7 +268,7 @@ func (m *MusigSessionManager) MuSig2RegisterNonces(sessionID MuSig2SessionID,
m.sessionMtx.Lock(sessionID) m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID) defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions[sessionID] session, ok := m.musig2Sessions.Load(sessionID)
if !ok { if !ok {
return false, fmt.Errorf("session with ID %x not found", return false, fmt.Errorf("session with ID %x not found",
sessionID[:]) sessionID[:])