diff --git a/api.go b/api.go index 289baa6..dd18df9 100644 --- a/api.go +++ b/api.go @@ -20,16 +20,25 @@ const ( msgTypeStore string = "store" ) -var ( - socketSubscriptions = map[string]func(msg interface{}) error{} - socketSubscriptionsLock = new(sync.RWMutex) -) +var subscriptions = newSubscriptionStore() -func sendAllSockets(msgType string, msg interface{}) error { - socketSubscriptionsLock.RLock() - defer socketSubscriptionsLock.RUnlock() +type subcriptionStore struct { + socketSubscriptions map[string]func(interface{}) error + socketSubscriptionsLock *sync.RWMutex +} - for _, hdl := range socketSubscriptions { +func newSubscriptionStore() *subcriptionStore { + return &subcriptionStore{ + socketSubscriptions: map[string]func(msg interface{}) error{}, + socketSubscriptionsLock: new(sync.RWMutex), + } +} + +func (s subcriptionStore) SendAllSockets(msgType string, msg interface{}) error { + s.socketSubscriptionsLock.RLock() + defer s.socketSubscriptionsLock.RUnlock() + + for _, hdl := range s.socketSubscriptions { if err := hdl(compileSocketMessage(msgType, msg)); err != nil { return errors.Wrap(err, "submit message") } @@ -38,18 +47,18 @@ func sendAllSockets(msgType string, msg interface{}) error { return nil } -func subscribeSocket(id string, hdl func(interface{}) error) { - socketSubscriptionsLock.Lock() - defer socketSubscriptionsLock.Unlock() +func (s *subcriptionStore) SubscribeSocket(id string, hdl func(interface{}) error) { + s.socketSubscriptionsLock.Lock() + defer s.socketSubscriptionsLock.Unlock() - socketSubscriptions[id] = hdl + s.socketSubscriptions[id] = hdl } -func unsubscribeSocket(id string) { - socketSubscriptionsLock.Lock() - defer socketSubscriptionsLock.Unlock() +func (s *subcriptionStore) UnsubscribeSocket(id string) { + s.socketSubscriptionsLock.Lock() + defer s.socketSubscriptionsLock.Unlock() - delete(socketSubscriptions, id) + delete(s.socketSubscriptions, id) } func compileSocketMessage(msgType string, msg interface{}) map[string]interface{} { @@ -97,13 +106,13 @@ func handleUpdateSocket(w http.ResponseWriter, r *http.Request) { connLock = new(sync.Mutex) id = uuid.Must(uuid.NewV4()).String() ) - subscribeSocket(id, func(msg interface{}) error { + subscriptions.SubscribeSocket(id, func(msg interface{}) error { connLock.Lock() defer connLock.Unlock() return conn.WriteJSON(msg) }) - defer unsubscribeSocket(id) + defer subscriptions.UnsubscribeSocket(id) keepAlive := time.NewTicker(5 * time.Second) defer keepAlive.Stop() diff --git a/main.go b/main.go index daa8be0..885b9e7 100644 --- a/main.go +++ b/main.go @@ -110,7 +110,7 @@ func main() { } case <-timerForceSync.C: - if err := sendAllSockets(msgTypeStore, store); err != nil { + if err := subscriptions.SendAllSockets(msgTypeStore, store); err != nil { log.WithError(err).Error("Unable to send store to all sockets") } diff --git a/stats.go b/stats.go index 8c50d96..a38ffab 100644 --- a/stats.go +++ b/stats.go @@ -77,7 +77,7 @@ func updateFollowers() error { } return errors.Wrap( - sendAllSockets(msgTypeStore, store), + subscriptions.SendAllSockets(msgTypeStore, store), "update all sockets", ) } diff --git a/webhook.go b/webhook.go index 4539676..7dbfe48 100644 --- a/webhook.go +++ b/webhook.go @@ -98,7 +98,7 @@ func handleWebHookPush(w http.ResponseWriter, r *http.Request) { logger.WithError(err).Error("Unable to update persistent store") } - if err := sendAllSockets(msgTypeStore, store); err != nil { + if err := subscriptions.SendAllSockets(msgTypeStore, store); err != nil { logger.WithError(err).Error("Unable to send update to all sockets") } }