diff --git a/stream/messaging.go b/stream/messaging.go index ea1501c..4876d63 100644 --- a/stream/messaging.go +++ b/stream/messaging.go @@ -62,18 +62,15 @@ func (s *Stream) Close() { // Register a new output on a stream. // If hidden in true, then do not count this client. -func (s *Stream) Register(output chan []byte, hidden bool) { +func (s *Stream) Register(output chan []byte) { s.lock.Lock() defer s.lock.Unlock() s.outputs[output] = struct{}{} - if !hidden { - s.nbClients++ - } } // Unregister removes an output. // If hidden in true, then do not count this client. -func (s *Stream) Unregister(output chan []byte, hidden bool) { +func (s *Stream) Unregister(output chan []byte) { s.lock.Lock() defer s.lock.Unlock() @@ -82,13 +79,20 @@ func (s *Stream) Unregister(output chan []byte, hidden bool) { if ok { delete(s.outputs, output) close(output) - if !hidden { - s.nbClients-- - } } } -// Count number of clients -func (s *Stream) Count() int { +// ClientCount returns the number of clients +func (s *Stream) ClientCount() int { return s.nbClients } + +// IncrementClientCount increments the number of clients +func (s *Stream) IncrementClientCount() { + s.nbClients++ +} + +// DecrementClientCount decrements the number of clients +func (s *Stream) DecrementClientCount() { + s.nbClients-- +} diff --git a/stream/messaging_test.go b/stream/messaging_test.go index f7141b2..49e17a0 100644 --- a/stream/messaging_test.go +++ b/stream/messaging_test.go @@ -16,7 +16,8 @@ func TestWithOneOutput(t *testing.T) { // Register one output output := make(chan []byte, 64) - stream.Register(output, false) + stream.Register(output) + stream.IncrementClientCount() // Try to pass one message stream.Broadcast <- []byte("hello world") @@ -26,15 +27,16 @@ func TestWithOneOutput(t *testing.T) { } // Check client count - if count := stream.Count(); count != 1 { + if count := stream.ClientCount(); count != 1 { t.Errorf("Client counter returned %d, expected 1", count) } // Unregister - stream.Unregister(output, false) + stream.Unregister(output) + stream.DecrementClientCount() // Check client count - if count := stream.Count(); count != 0 { + if count := stream.ClientCount(); count != 0 { t.Errorf("Client counter returned %d, expected 0", count) } } diff --git a/stream/srt/handler.go b/stream/srt/handler.go index 8ea3d98..9d521dc 100644 --- a/stream/srt/handler.go +++ b/stream/srt/handler.go @@ -66,7 +66,8 @@ func handleViewer(s *srtgo.SrtSocket, streams map[string]*stream.Stream, name st // Register new output c := make(chan []byte, 128) - st.Register(c, false) + st.Register(c) + st.IncrementClientCount() // Receive data and send them for data := range c { @@ -84,6 +85,7 @@ func handleViewer(s *srtgo.SrtSocket, streams map[string]*stream.Stream, name st } // Close output - st.Unregister(c, false) + st.Unregister(c) + st.DecrementClientCount() s.Close() }