mockbroker.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. package sarama
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "net"
  8. "reflect"
  9. "strconv"
  10. "sync"
  11. "time"
  12. "github.com/davecgh/go-spew/spew"
  13. )
  14. const (
  15. expectationTimeout = 500 * time.Millisecond
  16. )
  17. type GSSApiHandlerFunc func([]byte) []byte
  18. type requestHandlerFunc func(req *request) (res encoderWithHeader)
  19. // RequestNotifierFunc is invoked when a mock broker processes a request successfully
  20. // and will provides the number of bytes read and written.
  21. type RequestNotifierFunc func(bytesRead, bytesWritten int)
  22. // MockBroker is a mock Kafka broker that is used in unit tests. It is exposed
  23. // to facilitate testing of higher level or specialized consumers and producers
  24. // built on top of Sarama. Note that it does not 'mimic' the Kafka API protocol,
  25. // but rather provides a facility to do that. It takes care of the TCP
  26. // transport, request unmarshalling, response marshaling, and makes it the test
  27. // writer responsibility to program correct according to the Kafka API protocol
  28. // MockBroker behavior.
  29. //
  30. // MockBroker is implemented as a TCP server listening on a kernel-selected
  31. // localhost port that can accept many connections. It reads Kafka requests
  32. // from that connection and returns responses programmed by the SetHandlerByMap
  33. // function. If a MockBroker receives a request that it has no programmed
  34. // response for, then it returns nothing and the request times out.
  35. //
  36. // A set of MockRequest builders to define mappings used by MockBroker is
  37. // provided by Sarama. But users can develop MockRequests of their own and use
  38. // them along with or instead of the standard ones.
  39. //
  40. // When running tests with MockBroker it is strongly recommended to specify
  41. // a timeout to `go test` so that if the broker hangs waiting for a response,
  42. // the test panics.
  43. //
  44. // It is not necessary to prefix message length or correlation ID to your
  45. // response bytes, the server does that automatically as a convenience.
  46. type MockBroker struct {
  47. brokerID int32
  48. port int32
  49. closing chan none
  50. stopper chan none
  51. expectations chan encoderWithHeader
  52. listener net.Listener
  53. t TestReporter
  54. latency time.Duration
  55. handler requestHandlerFunc
  56. notifier RequestNotifierFunc
  57. history []RequestResponse
  58. lock sync.Mutex
  59. gssApiHandler GSSApiHandlerFunc
  60. }
  61. // RequestResponse represents a Request/Response pair processed by MockBroker.
  62. type RequestResponse struct {
  63. Request protocolBody
  64. Response encoder
  65. }
  66. // SetLatency makes broker pause for the specified period every time before
  67. // replying.
  68. func (b *MockBroker) SetLatency(latency time.Duration) {
  69. b.latency = latency
  70. }
  71. // SetHandlerByMap defines mapping of Request types to MockResponses. When a
  72. // request is received by the broker, it looks up the request type in the map
  73. // and uses the found MockResponse instance to generate an appropriate reply.
  74. // If the request type is not found in the map then nothing is sent.
  75. func (b *MockBroker) SetHandlerByMap(handlerMap map[string]MockResponse) {
  76. fnMap := make(map[string]MockResponse)
  77. for k, v := range handlerMap {
  78. fnMap[k] = v
  79. }
  80. b.setHandler(func(req *request) (res encoderWithHeader) {
  81. reqTypeName := reflect.TypeOf(req.body).Elem().Name()
  82. mockResponse := fnMap[reqTypeName]
  83. if mockResponse == nil {
  84. return nil
  85. }
  86. return mockResponse.For(req.body)
  87. })
  88. }
  89. // SetNotifier set a function that will get invoked whenever a request has been
  90. // processed successfully and will provide the number of bytes read and written
  91. func (b *MockBroker) SetNotifier(notifier RequestNotifierFunc) {
  92. b.lock.Lock()
  93. b.notifier = notifier
  94. b.lock.Unlock()
  95. }
  96. // BrokerID returns broker ID assigned to the broker.
  97. func (b *MockBroker) BrokerID() int32 {
  98. return b.brokerID
  99. }
  100. // History returns a slice of RequestResponse pairs in the order they were
  101. // processed by the broker. Note that in case of multiple connections to the
  102. // broker the order expected by a test can be different from the order recorded
  103. // in the history, unless some synchronization is implemented in the test.
  104. func (b *MockBroker) History() []RequestResponse {
  105. b.lock.Lock()
  106. history := make([]RequestResponse, len(b.history))
  107. copy(history, b.history)
  108. b.lock.Unlock()
  109. return history
  110. }
  111. // Port returns the TCP port number the broker is listening for requests on.
  112. func (b *MockBroker) Port() int32 {
  113. return b.port
  114. }
  115. // Addr returns the broker connection string in the form "<address>:<port>".
  116. func (b *MockBroker) Addr() string {
  117. return b.listener.Addr().String()
  118. }
  119. // Close terminates the broker blocking until it stops internal goroutines and
  120. // releases all resources.
  121. func (b *MockBroker) Close() {
  122. close(b.expectations)
  123. if len(b.expectations) > 0 {
  124. buf := bytes.NewBufferString(fmt.Sprintf("mockbroker/%d: not all expectations were satisfied! Still waiting on:\n", b.BrokerID()))
  125. for e := range b.expectations {
  126. _, _ = buf.WriteString(spew.Sdump(e))
  127. }
  128. b.t.Error(buf.String())
  129. }
  130. close(b.closing)
  131. <-b.stopper
  132. }
  133. // setHandler sets the specified function as the request handler. Whenever
  134. // a mock broker reads a request from the wire it passes the request to the
  135. // function and sends back whatever the handler function returns.
  136. func (b *MockBroker) setHandler(handler requestHandlerFunc) {
  137. b.lock.Lock()
  138. b.handler = handler
  139. b.lock.Unlock()
  140. }
  141. func (b *MockBroker) serverLoop() {
  142. defer close(b.stopper)
  143. var err error
  144. var conn net.Conn
  145. go func() {
  146. <-b.closing
  147. err := b.listener.Close()
  148. if err != nil {
  149. b.t.Error(err)
  150. }
  151. }()
  152. wg := &sync.WaitGroup{}
  153. i := 0
  154. for conn, err = b.listener.Accept(); err == nil; conn, err = b.listener.Accept() {
  155. wg.Add(1)
  156. go b.handleRequests(conn, i, wg)
  157. i++
  158. }
  159. wg.Wait()
  160. Logger.Printf("*** mockbroker/%d: listener closed, err=%v", b.BrokerID(), err)
  161. }
  162. func (b *MockBroker) SetGSSAPIHandler(handler GSSApiHandlerFunc) {
  163. b.gssApiHandler = handler
  164. }
  165. func (b *MockBroker) readToBytes(r io.Reader) ([]byte, error) {
  166. var (
  167. bytesRead int
  168. lengthBytes = make([]byte, 4)
  169. )
  170. if _, err := io.ReadFull(r, lengthBytes); err != nil {
  171. return nil, err
  172. }
  173. bytesRead += len(lengthBytes)
  174. length := int32(binary.BigEndian.Uint32(lengthBytes))
  175. if length <= 4 || length > MaxRequestSize {
  176. return nil, PacketDecodingError{fmt.Sprintf("message of length %d too large or too small", length)}
  177. }
  178. encodedReq := make([]byte, length)
  179. if _, err := io.ReadFull(r, encodedReq); err != nil {
  180. return nil, err
  181. }
  182. bytesRead += len(encodedReq)
  183. fullBytes := append(lengthBytes, encodedReq...)
  184. return fullBytes, nil
  185. }
  186. func (b *MockBroker) isGSSAPI(buffer []byte) bool {
  187. return buffer[4] == 0x60 || bytes.Equal(buffer[4:6], []byte{0x05, 0x04})
  188. }
  189. func (b *MockBroker) handleRequests(conn io.ReadWriteCloser, idx int, wg *sync.WaitGroup) {
  190. defer wg.Done()
  191. defer func() {
  192. _ = conn.Close()
  193. }()
  194. s := spew.NewDefaultConfig()
  195. s.MaxDepth = 1
  196. Logger.Printf("*** mockbroker/%d/%d: connection opened", b.BrokerID(), idx)
  197. var err error
  198. abort := make(chan none)
  199. defer close(abort)
  200. go func() {
  201. select {
  202. case <-b.closing:
  203. _ = conn.Close()
  204. case <-abort:
  205. }
  206. }()
  207. var bytesWritten int
  208. var bytesRead int
  209. for {
  210. buffer, err := b.readToBytes(conn)
  211. if err != nil {
  212. Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(buffer))
  213. b.serverError(err)
  214. break
  215. }
  216. bytesWritten = 0
  217. if !b.isGSSAPI(buffer) {
  218. req, br, err := decodeRequest(bytes.NewReader(buffer))
  219. bytesRead = br
  220. if err != nil {
  221. Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(req))
  222. b.serverError(err)
  223. break
  224. }
  225. if b.latency > 0 {
  226. time.Sleep(b.latency)
  227. }
  228. b.lock.Lock()
  229. res := b.handler(req)
  230. b.history = append(b.history, RequestResponse{req.body, res})
  231. b.lock.Unlock()
  232. if res == nil {
  233. Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(req))
  234. continue
  235. }
  236. Logger.Printf(
  237. "*** mockbroker/%d/%d: replied to %T with %T\n-> %s\n-> %s",
  238. b.brokerID, idx, req.body, res,
  239. s.Sprintf("%#v", req.body),
  240. s.Sprintf("%#v", res),
  241. )
  242. encodedRes, err := encode(res, nil)
  243. if err != nil {
  244. b.serverError(err)
  245. break
  246. }
  247. if len(encodedRes) == 0 {
  248. b.lock.Lock()
  249. if b.notifier != nil {
  250. b.notifier(bytesRead, 0)
  251. }
  252. b.lock.Unlock()
  253. continue
  254. }
  255. resHeader := b.encodeHeader(res.headerVersion(), req.correlationID, uint32(len(encodedRes)))
  256. if _, err = conn.Write(resHeader); err != nil {
  257. b.serverError(err)
  258. break
  259. }
  260. if _, err = conn.Write(encodedRes); err != nil {
  261. b.serverError(err)
  262. break
  263. }
  264. bytesWritten = len(resHeader) + len(encodedRes)
  265. } else {
  266. // GSSAPI is not part of kafka protocol, but is supported for authentication proposes.
  267. // Don't support history for this kind of request as is only used for test GSSAPI authentication mechanism
  268. b.lock.Lock()
  269. res := b.gssApiHandler(buffer)
  270. b.lock.Unlock()
  271. if res == nil {
  272. Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(buffer))
  273. continue
  274. }
  275. if _, err = conn.Write(res); err != nil {
  276. b.serverError(err)
  277. break
  278. }
  279. bytesWritten = len(res)
  280. }
  281. b.lock.Lock()
  282. if b.notifier != nil {
  283. b.notifier(bytesRead, bytesWritten)
  284. }
  285. b.lock.Unlock()
  286. }
  287. Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
  288. }
  289. func (b *MockBroker) encodeHeader(headerVersion int16, correlationId int32, payloadLength uint32) []byte {
  290. headerLength := uint32(8)
  291. if headerVersion >= 1 {
  292. headerLength = 9
  293. }
  294. resHeader := make([]byte, headerLength)
  295. binary.BigEndian.PutUint32(resHeader, payloadLength+headerLength-4)
  296. binary.BigEndian.PutUint32(resHeader[4:], uint32(correlationId))
  297. if headerVersion >= 1 {
  298. binary.PutUvarint(resHeader[8:], 0)
  299. }
  300. return resHeader
  301. }
  302. func (b *MockBroker) defaultRequestHandler(req *request) (res encoderWithHeader) {
  303. select {
  304. case res, ok := <-b.expectations:
  305. if !ok {
  306. return nil
  307. }
  308. return res
  309. case <-time.After(expectationTimeout):
  310. return nil
  311. }
  312. }
  313. func (b *MockBroker) serverError(err error) {
  314. isConnectionClosedError := false
  315. if _, ok := err.(*net.OpError); ok {
  316. isConnectionClosedError = true
  317. } else if err == io.EOF {
  318. isConnectionClosedError = true
  319. } else if err.Error() == "use of closed network connection" {
  320. isConnectionClosedError = true
  321. }
  322. if isConnectionClosedError {
  323. return
  324. }
  325. b.t.Errorf(err.Error())
  326. }
  327. // NewMockBroker launches a fake Kafka broker. It takes a TestReporter as provided by the
  328. // test framework and a channel of responses to use. If an error occurs it is
  329. // simply logged to the TestReporter and the broker exits.
  330. func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
  331. return NewMockBrokerAddr(t, brokerID, "localhost:0")
  332. }
  333. // NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
  334. // it rather than just some ephemeral port.
  335. func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
  336. listener, err := net.Listen("tcp", addr)
  337. if err != nil {
  338. t.Fatal(err)
  339. }
  340. return NewMockBrokerListener(t, brokerID, listener)
  341. }
  342. // NewMockBrokerListener behaves like newMockBrokerAddr but accepts connections on the listener specified.
  343. func NewMockBrokerListener(t TestReporter, brokerID int32, listener net.Listener) *MockBroker {
  344. var err error
  345. broker := &MockBroker{
  346. closing: make(chan none),
  347. stopper: make(chan none),
  348. t: t,
  349. brokerID: brokerID,
  350. expectations: make(chan encoderWithHeader, 512),
  351. listener: listener,
  352. }
  353. broker.handler = broker.defaultRequestHandler
  354. Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
  355. _, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
  356. if err != nil {
  357. t.Fatal(err)
  358. }
  359. tmp, err := strconv.ParseInt(portStr, 10, 32)
  360. if err != nil {
  361. t.Fatal(err)
  362. }
  363. broker.port = int32(tmp)
  364. go broker.serverLoop()
  365. return broker
  366. }
  367. func (b *MockBroker) Returns(e encoderWithHeader) {
  368. b.expectations <- e
  369. }