Go: Broadcast channels?

In Go, channels are a key and convenient mechanism for communication. However, if we want to pass a message to multiple listeners then we need multiple channels. That is, writes to channels aren't broadcast. As soon as one listener reads a value from the channel, that value is removed from the channel, and other listeners listening on the same channel will not see it. If we want to send a message (or value) to multiple listeners we need to write that message to \(N\) different channels.

package main

import (
	"fmt"
	"sync"
)

// main starts two goroutines that receive from the same unbuffered channel
// and then sends the values 1 and 2 on it. Each value is handed to exactly
// one of the receivers; which one gets it depends on scheduling.
func main() {
	values := make(chan int)

	var listeners sync.WaitGroup

	// listen prints every value it receives, prefixed with the listener's
	// name, until the channel is closed.
	listen := func(name string) {
		defer listeners.Done()
		for n := range values {
			fmt.Println(name, n)
		}
	}

	listeners.Add(2)
	go listen("A")
	go listen("B")

	for _, n := range []int{1, 2} {
		values <- n
	}
	close(values)

	// Block until both listeners have seen the close and returned.
	listeners.Wait()
}

The above example will print two lines, one for the value 1 and one for the value 2. Depending on the exact scheduling it might print A 1 \ B 2, or A 1 \ A 2, or B 1 \ B 2, or any other such variant, but each value is printed exactly once because only one goroutine receives any given value. If we want each value to be seen by all listening goroutines we need more channels.

package main

import (
	"fmt"
	"sync"
)

// main gives each listener its own channel and writes every value to both
// channels, so listener A and listener B each print all ten values (0-9).
func main() {
	toA := make(chan int)
	toB := make(chan int)
	outputs := []chan int{toA, toB}

	var listeners sync.WaitGroup

	// printAll prints everything arriving on in, prefixed with label, until
	// in is closed.
	printAll := func(label string, in <-chan int) {
		defer listeners.Done()
		for n := range in {
			fmt.Println(label, n)
		}
	}

	listeners.Add(2)
	go printAll("A", toA)
	go printAll("B", toB)

	// The sends are unbuffered, so each value is handed to A first and then
	// to B before the next value is produced.
	for n := 0; n < 10; n++ {
		for _, out := range outputs {
			out <- n
		}
	}
	for _, out := range outputs {
		close(out)
	}
	listeners.Wait()
}

Now each listener sees every value. We're not quite happy with this, though. Now we have to keep track of many channels, and what if we want to dynamically add or remove listeners? Let's look at a first intermediate improvement.

package main

import (
	"fmt"
	"sync"
)

// main introduces a single input channel whose values are copied by a
// forwarding goroutine to one channel per listener. The producer writes each
// value once; listeners A and B each print all ten values (0-9).
func main() {
	source := make(chan int)
	sinkA := make(chan int)
	sinkB := make(chan int)
	sinks := []chan int{sinkA, sinkB}

	var running sync.WaitGroup
	running.Add(3)

	// Forwarder: copy every value from source to each sink, in order. When
	// source is closed, close the sinks so the listeners terminate too.
	go func() {
		defer running.Done()
		for n := range source {
			for _, sink := range sinks {
				sink <- n
			}
		}
		for _, sink := range sinks {
			close(sink)
		}
	}()

	// printAll prints everything arriving on in, prefixed with label, until
	// in is closed.
	printAll := func(label string, in <-chan int) {
		defer running.Done()
		for n := range in {
			fmt.Println(label, n)
		}
	}
	go printAll("A", sinkA)
	go printAll("B", sinkB)

	for n := 0; n < 10; n++ {
		source <- n
	}
	close(source)
	running.Wait()
}

Notice the difference? Now we only have to write our values to one single channel chBroadcast. This is now our broadcast channel. But now let's get more complicated with dynamically adding and removing listeners!

package main

import (
	"fmt"
	"sync"
)

// BroadcastService copies every value written to its input channel to a
// dynamic set of listener channels.
//
// All state is owned by the single goroutine started by Run; the exported
// methods talk to that goroutine over channels, so no locking is needed.
//
// Forwarding uses blocking, unbuffered sends: a listener that stops
// receiving stalls the whole service. In particular, the goroutine that is
// supposed to drain a listener channel must not call RemoveListener on that
// channel while a value may be pending for it, or both sides can block on
// each other.
type BroadcastService struct {
	// chBroadcast is the input channel. Values received on it are forwarded
	// to every listener; closing it shuts the service down.
	chBroadcast chan int
	// chListeners holds the registered listener channels. A nil entry is a
	// free slot left behind by a removed listener.
	chListeners []chan int
	// chNewRequests carries listener channels waiting to be registered.
	chNewRequests chan (chan int)
	// chRemoveRequests carries listener channels waiting to be removed.
	chRemoveRequests chan (chan int)
	// done is closed once the service goroutine has shut down, so that
	// Listener and RemoveListener never block on a service that is gone.
	done chan struct{}
}

// NewBroadcastService creates a BroadcastService. The service does nothing
// until Run is called.
func NewBroadcastService() *BroadcastService {
	return &BroadcastService{
		chBroadcast:      make(chan int),
		chListeners:      make([]chan int, 0, 3),
		chNewRequests:    make(chan (chan int)),
		chRemoveRequests: make(chan (chan int)),
		done:             make(chan struct{}),
	}
}

// Listener registers a new listener and returns the channel a goroutine
// should receive from. The channel is closed when the listener is removed or
// when the service shuts down.
//
// Listener blocks until the service goroutine accepts the registration, so
// Run must have been called. If the service has already shut down, the
// returned channel is already closed.
func (bs *BroadcastService) Listener() chan int {
	ch := make(chan int)
	select {
	case bs.chNewRequests <- ch:
	case <-bs.done:
		// Nobody will ever write to or close ch; close it here so a range
		// over it ends immediately instead of blocking forever.
		close(ch)
	}
	return ch
}

// RemoveListener unregisters ch and closes it. Channels that are not
// currently registered are ignored. After the service has shut down this is
// a no-op, because shutdown already closed every registered listener.
func (bs *BroadcastService) RemoveListener(ch chan int) {
	select {
	case bs.chRemoveRequests <- ch:
	case <-bs.done:
	}
}

// addListener stores ch in the first free slot, growing the slice only when
// every slot is occupied. It must only be called from the service goroutine.
func (bs *BroadcastService) addListener(ch chan int) {
	for i, v := range bs.chListeners {
		if v == nil {
			bs.chListeners[i] = ch
			return
		}
	}

	bs.chListeners = append(bs.chListeners, ch)
}

// removeListener frees the slot holding ch and closes ch. Unknown channels
// are left untouched, which also prevents closing the same channel twice.
// It must only be called from the service goroutine.
func (bs *BroadcastService) removeListener(ch chan int) {
	for i, v := range bs.chListeners {
		if v == ch {
			bs.chListeners[i] = nil
			// Closing is essential: otherwise the goroutine receiving from
			// ch might block forever.
			close(ch)
			return
		}
	}
}

// forward sends v to every registered listener, in slot order. Each send
// blocks until that listener receives the value.
func (bs *BroadcastService) forward(v int) {
	for _, dstCh := range bs.chListeners {
		if dstCh != nil {
			dstCh <- v
		}
	}
}

// closeListeners closes every registered listener channel and frees its
// slot.
func (bs *BroadcastService) closeListeners() {
	for i, dstCh := range bs.chListeners {
		if dstCh != nil {
			close(dstCh)
			bs.chListeners[i] = nil
		}
	}
}

// serve is the service loop. It handles registration and removal requests
// and forwards broadcast values until the input channel is closed.
func (bs *BroadcastService) serve() {
	// Deferred calls run in reverse order: listeners are closed first, then
	// done is closed to release any caller still waiting in Listener or
	// RemoveListener.
	defer close(bs.done)
	defer bs.closeListeners()

	for {
		select {
		case newCh := <-bs.chNewRequests:
			bs.addListener(newCh)
		case removeCh := <-bs.chRemoveRequests:
			bs.removeListener(removeCh)
		case v, ok := <-bs.chBroadcast:
			// A closed input channel terminates the service.
			if !ok {
				return
			}
			bs.forward(v)
		}
	}
}

// Run starts the service goroutine and returns the input channel. Every
// value sent on the returned channel is forwarded to all listeners
// registered at that moment; closing it closes all listener channels and
// stops the service.
//
// Run must be called exactly once per BroadcastService.
func (bs *BroadcastService) Run() chan int {
	go bs.serve()
	return bs.chBroadcast
}

// main exercises the broadcast service: two listeners are registered, three
// values are broadcast, listener A is removed, and three more values are
// broadcast before the input channel is closed. Listener A therefore prints
// 0-2 and listener B prints 0-5.
func main() {
	service := NewBroadcastService()
	input := service.Run()
	listenerA := service.Listener()
	listenerB := service.Listener()

	var active sync.WaitGroup
	active.Add(2)

	// drain prints everything arriving on in, prefixed with label, until in
	// is closed.
	drain := func(label string, in chan int) {
		defer active.Done()
		for n := range in {
			fmt.Println(label, n)
		}
	}
	go drain("A", listenerA)
	go drain("B", listenerB)

	// emit broadcasts the integers in [from, to).
	emit := func(from, to int) {
		for n := from; n < to; n++ {
			input <- n
		}
	}

	emit(0, 3)
	service.RemoveListener(listenerA)
	emit(3, 6)

	close(input)
	active.Wait()
}