// Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package http2 import ( "bytes" "context" "crypto/tls" "errors" "flag" "fmt" "io" "io/ioutil" "log" "net" "net/http" "net/http/httptest" "os" "os/exec" "reflect" "runtime" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "golang.org/x/net/http2/hpack" ) var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered") func stderrv() io.Writer { if *stderrVerbose { return os.Stderr } return ioutil.Discard } type serverTester struct { cc net.Conn // client conn t testing.TB ts *httptest.Server fr *Framer serverLogBuf bytes.Buffer // logger for httptest.Server logFilter []string // substrings to filter out scMu sync.Mutex // guards sc sc *serverConn hpackDec *hpack.Decoder decodedHeaders [][2]string // If http2debug!=2, then we capture Frame debug logs that will be written // to t.Log after a test fails. The read and write logs use separate locks // and buffers so we don't accidentally introduce synchronization between // the read and write goroutines, which may hide data races. frameReadLogMu sync.Mutex frameReadLogBuf bytes.Buffer frameWriteLogMu sync.Mutex frameWriteLogBuf bytes.Buffer // writing headers: headerBuf bytes.Buffer hpackEnc *hpack.Encoder } func init() { testHookOnPanicMu = new(sync.Mutex) goAwayTimeout = 25 * time.Millisecond } func resetHooks() { testHookOnPanicMu.Lock() testHookOnPanic = nil testHookOnPanicMu.Unlock() } type serverTesterOpt string var optOnlyServer = serverTesterOpt("only_server") var optQuiet = serverTesterOpt("quiet_logging") var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames") func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { resetHooks() ts := httptest.NewUnstartedServer(handler) tlsConfig := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{NextProtoTLS}, } var onlyServer, quiet, framerReuseFrames bool h2server := new(Server) for _, opt := range opts { switch v := opt.(type) { case func(*tls.Config): v(tlsConfig) case func(*httptest.Server): v(ts) case func(*Server): v(h2server) case serverTesterOpt: switch v { case optOnlyServer: onlyServer = true case optQuiet: quiet = true case optFramerReuseFrames: framerReuseFrames = true } case func(net.Conn, http.ConnState): ts.Config.ConnState = v default: t.Fatalf("unknown newServerTester option type %T", v) } } ConfigureServer(ts.Config, h2server) st := &serverTester{ t: t, ts: ts, } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config if quiet { ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) } else { ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) } ts.StartTLS() if VerboseLogs { t.Logf("Running test server at: %s", ts.URL) } testHookGetServerConn = func(v *serverConn) { st.scMu.Lock() defer st.scMu.Unlock() st.sc = v } log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st})) if !onlyServer { cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) if err != nil { t.Fatal(err) } st.cc = cc st.fr = NewFramer(cc, cc) if framerReuseFrames { st.fr.SetReuseFrames() } if !logFrameReads && !logFrameWrites { st.fr.debugReadLoggerf = func(m string, v ...interface{}) { m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" st.frameReadLogMu.Lock() fmt.Fprintf(&st.frameReadLogBuf, m, v...) st.frameReadLogMu.Unlock() } st.fr.debugWriteLoggerf = func(m string, v ...interface{}) { m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n" st.frameWriteLogMu.Lock() fmt.Fprintf(&st.frameWriteLogBuf, m, v...) st.frameWriteLogMu.Unlock() } st.fr.logReads = true st.fr.logWrites = true } } return st } func (st *serverTester) closeConn() { st.scMu.Lock() defer st.scMu.Unlock() st.sc.conn.Close() } func (st *serverTester) addLogFilter(phrase string) { st.logFilter = append(st.logFilter, phrase) } func (st *serverTester) stream(id uint32) *stream { ch := make(chan *stream, 1) st.sc.serveMsgCh <- func(int) { ch <- st.sc.streams[id] } return <-ch } func (st *serverTester) streamState(id uint32) streamState { ch := make(chan streamState, 1) st.sc.serveMsgCh <- func(int) { state, _ := st.sc.state(id) ch <- state } return <-ch } // loopNum reports how many times this conn's select loop has gone around. func (st *serverTester) loopNum() int { lastc := make(chan int, 1) st.sc.serveMsgCh <- func(loopNum int) { lastc <- loopNum } return <-lastc } // awaitIdle heuristically awaits for the server conn's select loop to be idle. // The heuristic is that the server connection's serve loop must schedule // 50 times in a row without any channel sends or receives occurring. func (st *serverTester) awaitIdle() { remain := 50 last := st.loopNum() for remain > 0 { n := st.loopNum() if n == last+1 { remain-- } else { remain = 50 } last = n } } func (st *serverTester) Close() { if st.t.Failed() { st.frameReadLogMu.Lock() if st.frameReadLogBuf.Len() > 0 { st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String()) } st.frameReadLogMu.Unlock() st.frameWriteLogMu.Lock() if st.frameWriteLogBuf.Len() > 0 { st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String()) } st.frameWriteLogMu.Unlock() // If we failed already (and are likely in a Fatal, // unwindowing), force close the connection, so the // httptest.Server doesn't wait forever for the conn // to close. if st.cc != nil { st.cc.Close() } } st.ts.Close() if st.cc != nil { st.cc.Close() } log.SetOutput(os.Stderr) } // greet initiates the client's HTTP/2 connection into a state where // frames may be sent. func (st *serverTester) greet() { st.greetAndCheckSettings(func(Setting) error { return nil }) } func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) { st.writePreface() st.writeInitialSettings() st.wantSettings().ForeachSetting(checkSetting) st.writeSettingsAck() // The initial WINDOW_UPDATE and SETTINGS ACK can come in any order. var gotSettingsAck bool var gotWindowUpdate bool for i := 0; i < 2; i++ { f, err := st.readFrame() if err != nil { st.t.Fatal(err) } switch f := f.(type) { case *SettingsFrame: if !f.Header().Flags.Has(FlagSettingsAck) { st.t.Fatal("Settings Frame didn't have ACK set") } gotSettingsAck = true case *WindowUpdateFrame: if f.FrameHeader.StreamID != 0 { st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID) } incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize) if f.Increment != incr { st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr) } gotWindowUpdate = true default: st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f) } } if !gotSettingsAck { st.t.Fatalf("Didn't get a settings ACK") } if !gotWindowUpdate { st.t.Fatalf("Didn't get a window update") } } func (st *serverTester) writePreface() { n, err := st.cc.Write(clientPreface) if err != nil { st.t.Fatalf("Error writing client preface: %v", err) } if n != len(clientPreface) { st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface)) } } func (st *serverTester) writeInitialSettings() { if err := st.fr.WriteSettings(); err != nil { st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) } } func (st *serverTester) writeSettingsAck() { if err := st.fr.WriteSettingsAck(); err != nil { st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) } } func (st *serverTester) writeHeaders(p HeadersFrameParam) { if err := st.fr.WriteHeaders(p); err != nil { st.t.Fatalf("Error writing HEADERS: %v", err) } } func (st *serverTester) writePriority(id uint32, p PriorityParam) { if err := st.fr.WritePriority(id, p); err != nil { st.t.Fatalf("Error writing PRIORITY: %v", err) } } func (st *serverTester) encodeHeaderField(k, v string) { err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) if err != nil { st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) } } // encodeHeaderRaw is the magic-free version of encodeHeader. // It takes 0 or more (k, v) pairs and encodes them. func (st *serverTester) encodeHeaderRaw(headers ...string) []byte { if len(headers)%2 == 1 { panic("odd number of kv args") } st.headerBuf.Reset() for len(headers) > 0 { k, v := headers[0], headers[1] st.encodeHeaderField(k, v) headers = headers[2:] } return st.headerBuf.Bytes() } // encodeHeader encodes headers and returns their HPACK bytes. headers // must contain an even number of key/value pairs. There may be // multiple pairs for keys (e.g. "cookie"). The :method, :path, and // :scheme headers default to GET, / and https. The :authority header // defaults to st.ts.Listener.Addr(). func (st *serverTester) encodeHeader(headers ...string) []byte { if len(headers)%2 == 1 { panic("odd number of kv args") } st.headerBuf.Reset() defaultAuthority := st.ts.Listener.Addr().String() if len(headers) == 0 { // Fast path, mostly for benchmarks, so test code doesn't pollute // profiles when we're looking to improve server allocations. st.encodeHeaderField(":method", "GET") st.encodeHeaderField(":scheme", "https") st.encodeHeaderField(":authority", defaultAuthority) st.encodeHeaderField(":path", "/") return st.headerBuf.Bytes() } if len(headers) == 2 && headers[0] == ":method" { // Another fast path for benchmarks. st.encodeHeaderField(":method", headers[1]) st.encodeHeaderField(":scheme", "https") st.encodeHeaderField(":authority", defaultAuthority) st.encodeHeaderField(":path", "/") return st.headerBuf.Bytes() } pseudoCount := map[string]int{} keys := []string{":method", ":scheme", ":authority", ":path"} vals := map[string][]string{ ":method": {"GET"}, ":scheme": {"https"}, ":authority": {defaultAuthority}, ":path": {"/"}, } for len(headers) > 0 { k, v := headers[0], headers[1] headers = headers[2:] if _, ok := vals[k]; !ok { keys = append(keys, k) } if strings.HasPrefix(k, ":") { pseudoCount[k]++ if pseudoCount[k] == 1 { vals[k] = []string{v} } else { // Allows testing of invalid headers w/ dup pseudo fields. vals[k] = append(vals[k], v) } } else { vals[k] = append(vals[k], v) } } for _, k := range keys { for _, v := range vals[k] { st.encodeHeaderField(k, v) } } return st.headerBuf.Bytes() } // bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set. func (st *serverTester) bodylessReq1(headers ...string) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(headers...), EndStream: true, EndHeaders: true, }) } func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { if err := st.fr.WriteData(streamID, endStream, data); err != nil { st.t.Fatalf("Error writing DATA: %v", err) } } func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { st.t.Fatalf("Error writing DATA: %v", err) } } func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) { ch := make(chan interface{}, 1) go func() { fr, err := fr.ReadFrame() if err != nil { ch <- err } else { ch <- fr } }() t := time.NewTimer(wait) select { case v := <-ch: t.Stop() if fr, ok := v.(Frame); ok { return fr, nil } return nil, v.(error) case <-t.C: return nil, errors.New("timeout waiting for frame") } } func (st *serverTester) readFrame() (Frame, error) { return readFrameTimeout(st.fr, 2*time.Second) } func (st *serverTester) wantHeaders() *HeadersFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a HEADERS frame: %v", err) } hf, ok := f.(*HeadersFrame) if !ok { st.t.Fatalf("got a %T; want *HeadersFrame", f) } return hf } func (st *serverTester) wantContinuation() *ContinuationFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err) } cf, ok := f.(*ContinuationFrame) if !ok { st.t.Fatalf("got a %T; want *ContinuationFrame", f) } return cf } func (st *serverTester) wantData() *DataFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a DATA frame: %v", err) } df, ok := f.(*DataFrame) if !ok { st.t.Fatalf("got a %T; want *DataFrame", f) } return df } func (st *serverTester) wantSettings() *SettingsFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) } sf, ok := f.(*SettingsFrame) if !ok { st.t.Fatalf("got a %T; want *SettingsFrame", f) } return sf } func (st *serverTester) wantPing() *PingFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a PING frame: %v", err) } pf, ok := f.(*PingFrame) if !ok { st.t.Fatalf("got a %T; want *PingFrame", f) } return pf } func (st *serverTester) wantGoAway() *GoAwayFrame { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err) } gf, ok := f.(*GoAwayFrame) if !ok { st.t.Fatalf("got a %T; want *GoAwayFrame", f) } return gf } func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting an RSTStream frame: %v", err) } rs, ok := f.(*RSTStreamFrame) if !ok { st.t.Fatalf("got a %T; want *RSTStreamFrame", f) } if rs.FrameHeader.StreamID != streamID { st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID) } if rs.ErrCode != errCode { st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode) } } func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { f, err := st.readFrame() if err != nil { st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err) } wu, ok := f.(*WindowUpdateFrame) if !ok { st.t.Fatalf("got a %T; want *WindowUpdateFrame", f) } if wu.FrameHeader.StreamID != streamID { st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID) } if wu.Increment != incr { st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr) } } func (st *serverTester) wantSettingsAck() { f, err := st.readFrame() if err != nil { st.t.Fatal(err) } sf, ok := f.(*SettingsFrame) if !ok { st.t.Fatalf("Wanting a settings ACK, received a %T", f) } if !sf.Header().Flags.Has(FlagSettingsAck) { st.t.Fatal("Settings Frame didn't have ACK set") } } func (st *serverTester) wantPushPromise() *PushPromiseFrame { f, err := st.readFrame() if err != nil { st.t.Fatal(err) } ppf, ok := f.(*PushPromiseFrame) if !ok { st.t.Fatalf("Wanted PushPromise, received %T", ppf) } return ppf } func TestServer(t *testing.T) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Foo", "Bar") gotReq <- true }) defer st.Close() covers("3.5", ` The server connection preface consists of a potentially empty SETTINGS frame ([SETTINGS]) that MUST be the first frame the server sends in the HTTP/2 connection. `) st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(), EndStream: true, // no DATA frames EndHeaders: true, }) select { case <-gotReq: case <-time.After(2 * time.Second): t.Error("timeout waiting for request") } } func TestServer_Request_Get(t *testing.T) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader("foo-bar", "some-value"), EndStream: true, // no DATA frames EndHeaders: true, }) }, func(r *http.Request) { if r.Method != "GET" { t.Errorf("Method = %q; want GET", r.Method) } if r.URL.Path != "/" { t.Errorf("URL.Path = %q; want /", r.URL.Path) } if r.ContentLength != 0 { t.Errorf("ContentLength = %v; want 0", r.ContentLength) } if r.Close { t.Error("Close = true; want false") } if !strings.Contains(r.RemoteAddr, ":") { t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr) } if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 { t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor) } wantHeader := http.Header{ "Foo-Bar": []string{"some-value"}, } if !reflect.DeepEqual(r.Header, wantHeader) { t.Errorf("Header = %#v; want %#v", r.Header, wantHeader) } if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 { t.Errorf("Read = %d, %v; want 0, EOF", n, err) } }) } func TestServer_Request_Get_PathSlashes(t *testing.T) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":path", "/%2f/"), EndStream: true, // no DATA frames EndHeaders: true, }) }, func(r *http.Request) { if r.RequestURI != "/%2f/" { t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI) } if r.URL.Path != "///" { t.Errorf("URL.Path = %q; want ///", r.URL.Path) } }) } // TODO: add a test with EndStream=true on the HEADERS but setting a // Content-Length anyway. Should we just omit it and force it to // zero? func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST"), EndStream: true, EndHeaders: true, }) }, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) } if r.ContentLength != 0 { t.Errorf("ContentLength = %v; want 0", r.ContentLength) } if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 { t.Errorf("Read = %d, %v; want 0, EOF", n, err) } }) } func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) { testBodyContents(t, -1, "", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, true, nil) // just kidding. empty body. }) } func TestServer_Request_Post_Body_OneData(t *testing.T) { const content = "Some content" testBodyContents(t, -1, content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, true, []byte(content)) }) } func TestServer_Request_Post_Body_TwoData(t *testing.T) { const content = "Some content" testBodyContents(t, -1, content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, false, []byte(content[:5])) st.writeData(1, true, []byte(content[5:])) }) } func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) { const content = "Some content" testBodyContents(t, int64(len(content)), content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader( ":method", "POST", "content-length", strconv.Itoa(len(content)), ), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, true, []byte(content)) }) } func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) { testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader( ":method", "POST", "content-length", "3", ), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, true, []byte("12")) }) } func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader( ":method", "POST", "content-length", "4", ), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, true, []byte("12345")) }) } func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) { testServerRequest(t, write, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) } if r.ContentLength != wantContentLength { t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength) } all, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatal(err) } if string(all) != wantBody { t.Errorf("Read = %q; want %q", all, wantBody) } if err := r.Body.Close(); err != nil { t.Fatalf("Close: %v", err) } }) } func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) { testServerRequest(t, write, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) } if r.ContentLength != wantContentLength { t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength) } all, err := ioutil.ReadAll(r.Body) if err == nil { t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.", wantReadError, all) } if !strings.Contains(err.Error(), wantReadError) { t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError) } if err := r.Body.Close(); err != nil { t.Fatalf("Close: %v", err) } }) } // Using a Host header, instead of :authority func TestServer_Request_Get_Host(t *testing.T) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":authority", "", "host", host), EndStream: true, EndHeaders: true, }) }, func(r *http.Request) { if r.Host != host { t.Errorf("Host = %q; want %q", r.Host, host) } }) } // Using an :authority pseudo-header, instead of Host func TestServer_Request_Get_Authority(t *testing.T) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":authority", host), EndStream: true, EndHeaders: true, }) }, func(r *http.Request) { if r.Host != host { t.Errorf("Host = %q; want %q", r.Host, host) } }) } func TestServer_Request_WithContinuation(t *testing.T) { wantHeader := http.Header{ "Foo-One": []string{"value-one"}, "Foo-Two": []string{"value-two"}, "Foo-Three": []string{"value-three"}, } testServerRequest(t, func(st *serverTester) { fullHeaders := st.encodeHeader( "foo-one", "value-one", "foo-two", "value-two", "foo-three", "value-three", ) remain := fullHeaders chunks := 0 for len(remain) > 0 { const maxChunkSize = 5 chunk := remain if len(chunk) > maxChunkSize { chunk = chunk[:maxChunkSize] } remain = remain[len(chunk):] if chunks == 0 { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: chunk, EndStream: true, // no DATA frames EndHeaders: false, // we'll have continuation frames }) } else { err := st.fr.WriteContinuation(1, len(remain) == 0, chunk) if err != nil { t.Fatal(err) } } chunks++ } if chunks < 2 { t.Fatal("too few chunks") } }, func(r *http.Request) { if !reflect.DeepEqual(r.Header, wantHeader) { t.Errorf("Header = %#v; want %#v", r.Header, wantHeader) } }) } // Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field") func TestServer_Request_CookieConcat(t *testing.T) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.bodylessReq1( ":authority", host, "cookie", "a=b", "cookie", "c=d", "cookie", "e=f", ) }, func(r *http.Request) { const want = "a=b; c=d; e=f" if got := r.Header.Get("Cookie"); got != want { t.Errorf("Cookie = %q; want %q", got, want) } }) } func TestServer_Request_Reject_CapitalHeader(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") }) } func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") }) } func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") }) } func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") }) } func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") }) } func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") }) } func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") }) } func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") }) } func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) { // 8.1.2.3 Request Pseudo-Header Fields // "All HTTP/2 requests MUST include exactly one valid value" ... testRejectRequest(t, func(st *serverTester) { st.addLogFilter("duplicate pseudo-header") st.bodylessReq1(":method", "GET", ":method", "POST") }) } func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) { // 8.1.2.3 Request Pseudo-Header Fields // "All pseudo-header fields MUST appear in the header block // before regular header fields. Any request or response that // contains a pseudo-header field that appears in a header // block after a regular header field MUST be treated as // malformed (Section 8.1.2.6)." testRejectRequest(t, func(st *serverTester) { st.addLogFilter("pseudo-header after regular header") var buf bytes.Buffer enc := hpack.NewEncoder(&buf) enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"}) enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"}) enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"}) enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"}) st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: buf.Bytes(), EndStream: true, EndHeaders: true, }) }) } func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") }) } func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") }) } func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") }) } func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) { testRejectRequest(t, func(st *serverTester) { st.addLogFilter(`invalid pseudo-header ":unknown_thing"`) st.bodylessReq1(":unknown_thing", "") }) } func testRejectRequest(t *testing.T, send func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("server request made it to handler; should've been rejected") }) defer st.Close() st.greet() send(st) st.wantRSTStream(1, ErrCodeProtocol) } func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("server request made it to handler; should've been rejected") }, optQuiet) defer st.Close() st.greet() send(st) gf := st.wantGoAway() if gf.ErrCode != ErrCodeProtocol { t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol) } } // Section 5.1, on idle connections: "Receiving any frame other than // HEADERS or PRIORITY on a stream in this state MUST be treated as a // connection error (Section 5.4.1) of type PROTOCOL_ERROR." func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) { testRejectRequestWithProtocolError(t, func(st *serverTester) { st.fr.WriteWindowUpdate(123, 456) }) } func TestRejectFrameOnIdle_Data(t *testing.T) { testRejectRequestWithProtocolError(t, func(st *serverTester) { st.fr.WriteData(123, true, nil) }) } func TestRejectFrameOnIdle_RSTStream(t *testing.T) { testRejectRequestWithProtocolError(t, func(st *serverTester) { st.fr.WriteRSTStream(123, ErrCodeCancel) }) } func TestServer_Request_Connect(t *testing.T) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeaderRaw( ":method", "CONNECT", ":authority", "example.com:123", ), EndStream: true, EndHeaders: true, }) }, func(r *http.Request) { if g, w := r.Method, "CONNECT"; g != w { t.Errorf("Method = %q; want %q", g, w) } if g, w := r.RequestURI, "example.com:123"; g != w { t.Errorf("RequestURI = %q; want %q", g, w) } if g, w := r.URL.Host, "example.com:123"; g != w { t.Errorf("URL.Host = %q; want %q", g, w) } }) } func TestServer_Request_Connect_InvalidPath(t *testing.T) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeaderRaw( ":method", "CONNECT", ":authority", "example.com:123", ":path", "/bogus", ), EndStream: true, EndHeaders: true, }) }) } func TestServer_Request_Connect_InvalidScheme(t *testing.T) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeaderRaw( ":method", "CONNECT", ":authority", "example.com:123", ":scheme", "https", ), EndStream: true, EndHeaders: true, }) }) } func TestServer_Ping(t *testing.T) { st := newServerTester(t, nil) defer st.Close() st.greet() // Server should ignore this one, since it has ACK set. ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128} if err := st.fr.WritePing(true, ackPingData); err != nil { t.Fatal(err) } // But the server should reply to this one, since ACK is false. pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} if err := st.fr.WritePing(false, pingData); err != nil { t.Fatal(err) } pf := st.wantPing() if !pf.Flags.Has(FlagPingAck) { t.Error("response ping doesn't have ACK set") } if pf.Data != pingData { t.Errorf("response ping has data %q; want %q", pf.Data, pingData) } } func TestServer_RejectsLargeFrames(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("see golang.org/issue/13434") } st := newServerTester(t, nil) defer st.Close() st.greet() // Write too large of a frame (too large by one byte) // We ignore the return value because it's expected that the server // will only read the first 9 bytes (the headre) and then disconnect. st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1)) gf := st.wantGoAway() if gf.ErrCode != ErrCodeFrameSize { t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize) } if st.serverLogBuf.Len() != 0 { // Previously we spun here for a bit until the GOAWAY disconnect // timer fired, logging while we fired. t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes()) } } func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { puppet := newHandlerPuppet() st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { puppet.act(w, r) }) defer st.Close() defer puppet.done() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // data coming EndHeaders: true, }) st.writeData(1, false, []byte("abcdef")) puppet.do(readBodyHandler(t, "abc")) st.wantWindowUpdate(0, 3) st.wantWindowUpdate(1, 3) puppet.do(readBodyHandler(t, "def")) st.wantWindowUpdate(0, 3) st.wantWindowUpdate(1, 3) st.writeData(1, true, []byte("ghijkl")) // END_STREAM here puppet.do(readBodyHandler(t, "ghi")) puppet.do(readBodyHandler(t, "jkl")) st.wantWindowUpdate(0, 3) st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM } // the version of the TestServer_Handler_Sends_WindowUpdate with padding. // See golang.org/issue/16556 func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { puppet := newHandlerPuppet() st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { puppet.act(w, r) }) defer st.Close() defer puppet.done() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, EndHeaders: true, }) st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0}) // Expect to immediately get our 5 bytes of padding back for // both the connection and stream (4 bytes of padding + 1 byte of length) st.wantWindowUpdate(0, 5) st.wantWindowUpdate(1, 5) puppet.do(readBodyHandler(t, "abc")) st.wantWindowUpdate(0, 3) st.wantWindowUpdate(1, 3) puppet.do(readBodyHandler(t, "def")) st.wantWindowUpdate(0, 3) st.wantWindowUpdate(1, 3) } func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { st := newServerTester(t, nil) defer st.Close() st.greet() if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil { t.Fatal(err) } gf := st.wantGoAway() if gf.ErrCode != ErrCodeFlowControl { t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl) } if gf.LastStreamID != 0 { t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0) } } func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) { inHandler := make(chan bool) blockHandler := make(chan bool) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { inHandler <- true <-blockHandler }) defer st.Close() defer close(blockHandler) st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // keep it open EndHeaders: true, }) <-inHandler // Send a bogus window update: if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil { t.Fatal(err) } st.wantRSTStream(1, ErrCodeFlowControl) } // testServerPostUnblock sends a hanging POST with unsent data to handler, // then runs fn once in the handler, and verifies that the error returned from // handler is acceptable. It fails if takes over 5 seconds for handler to exit. func testServerPostUnblock(t *testing.T, handler func(http.ResponseWriter, *http.Request) error, fn func(*serverTester), checkErr func(error), otherHeaders ...string) { inHandler := make(chan bool) errc := make(chan error, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { inHandler <- true errc <- handler(w, r) }) defer st.Close() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...), EndStream: false, // keep it open EndHeaders: true, }) <-inHandler fn(st) select { case err := <-errc: if checkErr != nil { checkErr(err) } case <-time.After(5 * time.Second): t.Fatal("timeout waiting for Handler to return") } } func TestServer_RSTStream_Unblocks_Read(t *testing.T) { testServerPostUnblock(t, func(w http.ResponseWriter, r *http.Request) (err error) { _, err = r.Body.Read(make([]byte, 1)) return }, func(st *serverTester) { if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { t.Fatal(err) } }, func(err error) { want := StreamError{StreamID: 0x1, Code: 0x8} if !reflect.DeepEqual(err, want) { t.Errorf("Read error = %v; want %v", err, want) } }, ) } func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) { // Run this test a bunch, because it doesn't always // deadlock. But with a bunch, it did. n := 50 if testing.Short() { n = 5 } for i := 0; i < n; i++ { testServer_RSTStream_Unblocks_Header_Write(t) } } func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) { inHandler := make(chan bool, 1) unblockHandler := make(chan bool, 1) headerWritten := make(chan bool, 1) wroteRST := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { inHandler <- true <-wroteRST w.Header().Set("foo", "bar") w.WriteHeader(200) w.(http.Flusher).Flush() headerWritten <- true <-unblockHandler }) defer st.Close() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // keep it open EndHeaders: true, }) <-inHandler if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { t.Fatal(err) } wroteRST <- true st.awaitIdle() select { case <-headerWritten: case <-time.After(2 * time.Second): t.Error("timeout waiting for header write") } unblockHandler <- true } func TestServer_DeadConn_Unblocks_Read(t *testing.T) { testServerPostUnblock(t, func(w http.ResponseWriter, r *http.Request) (err error) { _, err = r.Body.Read(make([]byte, 1)) return }, func(st *serverTester) { st.cc.Close() }, func(err error) { if err == nil { t.Error("unexpected nil error from Request.Body.Read") } }, ) } var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error { <-w.(http.CloseNotifier).CloseNotify() return nil } func TestServer_CloseNotify_After_RSTStream(t *testing.T) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { t.Fatal(err) } }, nil) } func TestServer_CloseNotify_After_ConnClose(t *testing.T) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil) } // that CloseNotify unblocks after a stream error due to the client's // problem that's unrelated to them explicitly canceling it (which is // TestServer_CloseNotify_After_RSTStream above) func TestServer_CloseNotify_After_StreamError(t *testing.T) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { // data longer than declared Content-Length => stream error st.writeData(1, true, []byte("1234")) }, nil, "content-length", "3") } func TestServer_StateTransitions(t *testing.T) { var st *serverTester inHandler := make(chan bool) writeData := make(chan bool) leaveHandler := make(chan bool) st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { inHandler <- true if st.stream(1) == nil { t.Errorf("nil stream 1 in handler") } if got, want := st.streamState(1), stateOpen; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } writeData <- true if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF { t.Errorf("body read = %d, %v; want 0, EOF", n, err) } if got, want := st.streamState(1), stateHalfClosedRemote; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } <-leaveHandler }) st.greet() if st.stream(1) != nil { t.Fatal("stream 1 should be empty") } if got := st.streamState(1); got != stateIdle { t.Fatalf("stream 1 should be idle; got %v", got) } st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, // keep it open EndHeaders: true, }) <-inHandler <-writeData st.writeData(1, true, nil) leaveHandler <- true hf := st.wantHeaders() if !hf.StreamEnded() { t.Fatal("expected END_STREAM flag") } if got, want := st.streamState(1), stateClosed; got != want { t.Errorf("at end, state is %v; want %v", got, want) } if st.stream(1) != nil { t.Fatal("at end, stream 1 should be gone") } } // test HEADERS w/o EndHeaders + another HEADERS (should get rejected) func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: false, }) st.writeHeaders(HeadersFrameParam{ // Not a continuation. StreamID: 3, // different stream. BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, }) }) } // test HEADERS w/o EndHeaders + PING (should get rejected) func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: false, }) if err := st.fr.WritePing(false, [8]byte{}); err != nil { t.Fatal(err) } }) } // test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected) func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, }) st.wantHeaders() if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil { t.Fatal(err) } }) } // test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: false, }) if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil { t.Fatal(err) } }) } // No HEADERS on stream 0. func TestServer_Rejects_Headers0(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writeHeaders(HeadersFrameParam{ StreamID: 0, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, }) }) } // No CONTINUATION on stream 0. func TestServer_Rejects_Continuation0(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.fr.AllowIllegalWrites = true if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil { t.Fatal(err) } }) } // No PRIORITY on stream 0. func TestServer_Rejects_Priority0(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writePriority(0, PriorityParam{StreamDep: 1}) }) } // No HEADERS frame with a self-dependence. func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, Priority: PriorityParam{StreamDep: 1}, }) }) } // No PRIORTY frame with a self-dependence. func TestServer_Rejects_PrioritySelfDependence(t *testing.T) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writePriority(1, PriorityParam{StreamDep: 1}) }) } func TestServer_Rejects_PushPromise(t *testing.T) { testServerRejectsConn(t, func(st *serverTester) { pp := PushPromiseParam{ StreamID: 1, PromiseID: 3, } if err := st.fr.WritePushPromise(pp); err != nil { t.Fatal(err) } }) } // testServerRejectsConn tests that the server hangs up with a GOAWAY // frame and a server close after the client does something // deserving a CONNECTION_ERROR. func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) st.addLogFilter("connection error: PROTOCOL_ERROR") defer st.Close() st.greet() writeReq(st) st.wantGoAway() errc := make(chan error, 1) go func() { fr, err := st.fr.ReadFrame() if err == nil { err = fmt.Errorf("got frame of type %T", fr) } errc <- err }() select { case err := <-errc: if err != io.EOF { t.Errorf("ReadFrame = %v; want io.EOF", err) } case <-time.After(2 * time.Second): t.Error("timeout waiting for disconnect") } } // testServerRejectsStream tests that the server sends a RST_STREAM with the provided // error code after a client sends a bogus request. func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() st.greet() writeReq(st) st.wantRSTStream(1, code) } // testServerRequest sets up an idle HTTP/2 connection and lets you // write a single request with writeReq, and then verify that the // *http.Request is built correctly in checkReq. func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if r.Body == nil { t.Fatal("nil Body") } checkReq(r) gotReq <- true }) defer st.Close() st.greet() writeReq(st) select { case <-gotReq: case <-time.After(2 * time.Second): t.Error("timeout waiting for request") } } func getSlash(st *serverTester) { st.bodylessReq1() } func TestServer_Response_NoData(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { // Nothing. return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if !hf.StreamEnded() { t.Fatal("want END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } }) } func TestServer_Response_NoData_Header_FooBar(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Foo-Bar", "some-value") return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if !hf.StreamEnded() { t.Fatal("want END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"foo-bar", "some-value"}, {"content-length", "0"}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } }) } func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Content-Type", "foo/bar") io.WriteString(w, msg) return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("don't want END_STREAM, expecting data") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"content-type", "foo/bar"}, {"content-length", strconv.Itoa(len(msg))}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } df := st.wantData() if !df.StreamEnded() { t.Error("expected DATA to have END_STREAM flag") } if got := string(df.Data()); got != msg { t.Errorf("got DATA %q; want %q", got, msg) } }) } func TestServer_Response_TransferEncoding_chunked(t *testing.T) { const msg = "hi" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Transfer-Encoding", "chunked") // should be stripped io.WriteString(w, msg) return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"content-type", "text/plain; charset=utf-8"}, {"content-length", strconv.Itoa(len(msg))}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } }) } // Header accessed only after the initial write. func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) w.Header().Set("foo", "should be ignored") return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"content-type", "text/html; charset=utf-8"}, {"content-length", strconv.Itoa(len(msg))}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } }) } // Header accessed before the initial write and later mutated. func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("foo", "proper value") io.WriteString(w, msg) w.Header().Set("foo", "should be ignored") return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"foo", "proper value"}, {"content-type", "text/html; charset=utf-8"}, {"content-length", strconv.Itoa(len(msg))}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } }) } func TestServer_Response_Data_SniffLenType(t *testing.T) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("don't want END_STREAM, expecting data") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"content-type", "text/html; charset=utf-8"}, {"content-length", strconv.Itoa(len(msg))}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } df := st.wantData() if !df.StreamEnded() { t.Error("expected DATA to have END_STREAM flag") } if got := string(df.Data()); got != msg { t.Errorf("got DATA %q; want %q", got, msg) } }) } func TestServer_Response_Header_Flush_MidWrite(t *testing.T) { const msg = "this is HTML" const msg2 = ", and this is the next chunk" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) w.(http.Flusher).Flush() io.WriteString(w, msg2) return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"content-type", "text/html; charset=utf-8"}, // sniffed // and no content-length } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } { df := st.wantData() if df.StreamEnded() { t.Error("unexpected END_STREAM flag") } if got := string(df.Data()); got != msg { t.Errorf("got DATA %q; want %q", got, msg) } } { df := st.wantData() if !df.StreamEnded() { t.Error("wanted END_STREAM flag on last data chunk") } if got := string(df.Data()); got != msg2 { t.Errorf("got DATA %q; want %q", got, msg2) } } }) } func TestServer_Response_LargeWrite(t *testing.T) { const size = 1 << 20 const maxFrameSize = 16 << 10 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { n, err := w.Write(bytes.Repeat([]byte("a"), size)) if err != nil { return fmt.Errorf("Write error: %v", err) } if n != size { return fmt.Errorf("wrong size %d from Write", n) } return nil }, func(st *serverTester) { if err := st.fr.WriteSettings( Setting{SettingInitialWindowSize, 0}, Setting{SettingMaxFrameSize, maxFrameSize}, ); err != nil { t.Fatal(err) } st.wantSettingsAck() getSlash(st) // make the single request // Give the handler quota to write: if err := st.fr.WriteWindowUpdate(1, size); err != nil { t.Fatal(err) } // Give the handler quota to write to connection-level // window as well if err := st.fr.WriteWindowUpdate(0, size); err != nil { t.Fatal(err) } hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"content-type", "text/plain; charset=utf-8"}, // sniffed // and no content-length } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } var bytes, frames int for { df := st.wantData() bytes += len(df.Data()) frames++ for _, b := range df.Data() { if b != 'a' { t.Fatal("non-'a' byte seen in DATA") } } if df.StreamEnded() { break } } if bytes != size { t.Errorf("Got %d bytes; want %d", bytes, size) } if want := int(size / maxFrameSize); frames < want || frames > want*2 { t.Errorf("Got %d frames; want %d", frames, size) } }) } // Test that the handler can't write more than the client allows func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) { // Make these reads. Before each read, the client adds exactly enough // flow-control to satisfy the read. Numbers chosen arbitrarily. reads := []int{123, 1, 13, 127} size := 0 for _, n := range reads { size += n } testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.(http.Flusher).Flush() n, err := w.Write(bytes.Repeat([]byte("a"), size)) if err != nil { return fmt.Errorf("Write error: %v", err) } if n != size { return fmt.Errorf("wrong size %d from Write", n) } return nil }, func(st *serverTester) { // Set the window size to something explicit for this test. // It's also how much initial data we expect. if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil { t.Fatal(err) } st.wantSettingsAck() getSlash(st) // make the single request hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } df := st.wantData() if got := len(df.Data()); got != reads[0] { t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got) } for _, quota := range reads[1:] { if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil { t.Fatal(err) } df := st.wantData() if int(quota) != len(df.Data()) { t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota) } } }) } // Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM. func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) { const size = 1 << 20 const maxFrameSize = 16 << 10 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.(http.Flusher).Flush() errc := make(chan error, 1) go func() { _, err := w.Write(bytes.Repeat([]byte("a"), size)) errc <- err }() select { case err := <-errc: if err == nil { return errors.New("unexpected nil error from Write in handler") } return nil case <-time.After(2 * time.Second): return errors.New("timeout waiting for Write in handler") } }, func(st *serverTester) { if err := st.fr.WriteSettings( Setting{SettingInitialWindowSize, 0}, Setting{SettingMaxFrameSize, maxFrameSize}, ); err != nil { t.Fatal(err) } st.wantSettingsAck() getSlash(st) // make the single request hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { t.Fatal(err) } }) } func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.(http.Flusher).Flush() // Nothing; send empty DATA return nil }, func(st *serverTester) { // Handler gets no data quota: if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil { t.Fatal(err) } st.wantSettingsAck() getSlash(st) // make the single request hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } df := st.wantData() if got := len(df.Data()); got != 0 { t.Fatalf("unexpected %d DATA bytes; want 0", got) } if !df.StreamEnded() { t.Fatal("DATA didn't have END_STREAM") } }) } func TestServer_Response_Automatic100Continue(t *testing.T) { const msg = "foo" const reply = "bar" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { if v := r.Header.Get("Expect"); v != "" { t.Errorf("Expect header = %q; want empty", v) } buf := make([]byte, len(msg)) // This read should trigger the 100-continue being sent. if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg { return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg) } _, err := io.WriteString(w, reply) return err }, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"), EndStream: false, EndHeaders: true, }) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "100"}, } if !reflect.DeepEqual(goth, wanth) { t.Fatalf("Got headers %v; want %v", goth, wanth) } // Okay, they sent status 100, so we can send our // gigantic and/or sensitive "foo" payload now. st.writeData(1, true, []byte(msg)) st.wantWindowUpdate(0, uint32(len(msg))) hf = st.wantHeaders() if hf.StreamEnded() { t.Fatal("expected data to follow") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } goth = st.decodeHeader(hf.HeaderBlockFragment()) wanth = [][2]string{ {":status", "200"}, {"content-type", "text/plain; charset=utf-8"}, {"content-length", strconv.Itoa(len(reply))}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } df := st.wantData() if string(df.Data()) != reply { t.Errorf("Client read %q; want %q", df.Data(), reply) } if !df.StreamEnded() { t.Errorf("expect data stream end") } }) } func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { errc := make(chan error, 1) testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { p := []byte("some data.\n") for { _, err := w.Write(p) if err != nil { errc <- err return nil } } }, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: false, EndHeaders: true, }) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("unexpected END_STREAM flag") } if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } // Close the connection and wait for the handler to (hopefully) notice. st.cc.Close() select { case <-errc: case <-time.After(5 * time.Second): t.Error("timeout") } }) } func TestServer_Rejects_Too_Many_Streams(t *testing.T) { const testPath = "/some/path" inHandler := make(chan uint32) leaveHandler := make(chan bool) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { id := w.(*responseWriter).rws.stream.id inHandler <- id if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath { t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath) } <-leaveHandler }) defer st.Close() st.greet() nextStreamID := uint32(1) streamID := func() uint32 { defer func() { nextStreamID += 2 }() return nextStreamID } sendReq := func(id uint32, headers ...string) { st.writeHeaders(HeadersFrameParam{ StreamID: id, BlockFragment: st.encodeHeader(headers...), EndStream: true, EndHeaders: true, }) } for i := 0; i < defaultMaxStreams; i++ { sendReq(streamID()) <-inHandler } defer func() { for i := 0; i < defaultMaxStreams; i++ { leaveHandler <- true } }() // And this one should cross the limit: // (It's also sent as a CONTINUATION, to verify we still track the decoder context, // even if we're rejecting it) rejectID := streamID() headerBlock := st.encodeHeader(":path", testPath) frag1, frag2 := headerBlock[:3], headerBlock[3:] st.writeHeaders(HeadersFrameParam{ StreamID: rejectID, BlockFragment: frag1, EndStream: true, EndHeaders: false, // CONTINUATION coming }) if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil { t.Fatal(err) } st.wantRSTStream(rejectID, ErrCodeProtocol) // But let a handler finish: leaveHandler <- true st.wantHeaders() // And now another stream should be able to start: goodID := streamID() sendReq(goodID, ":path", testPath) select { case got := <-inHandler: if got != goodID { t.Errorf("Got stream %d; want %d", got, goodID) } case <-time.After(3 * time.Second): t.Error("timeout waiting for handler") } } // So many response headers that the server needs to use CONTINUATION frames: func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { h := w.Header() for i := 0; i < 5000; i++ { h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i)) } return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.HeadersEnded() { t.Fatal("got unwanted END_HEADERS flag") } n := 0 for { n++ cf := st.wantContinuation() if cf.HeadersEnded() { break } } if n < 5 { t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n) } }) } // This previously crashed (reported by Mathieu Lonjaret as observed // while using Camlistore) because we got a DATA frame from the client // after the handler exited and our logic at the time was wrong, // keeping a stream in the map in stateClosed, which tickled an // invariant check later when we tried to remove that stream (via // defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop // ended. func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { // nothing return nil }, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: false, // DATA is coming EndHeaders: true, }) hf := st.wantHeaders() if !hf.HeadersEnded() || !hf.StreamEnded() { t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf) } // Sent when the a Handler closes while a client has // indicated it's still sending DATA: st.wantRSTStream(1, ErrCodeNo) // Now the handler has ended, so it's ended its // stream, but the client hasn't closed its side // (stateClosedLocal). So send more data and verify // it doesn't crash with an internal invariant panic, like // it did before. st.writeData(1, true, []byte("foo")) // Get our flow control bytes back, since the handler didn't get them. st.wantWindowUpdate(0, uint32(len("foo"))) // Sent after a peer sends data anyway (admittedly the // previous RST_STREAM might've still been in-flight), // but they'll get the more friendly 'cancel' code // first. st.wantRSTStream(1, ErrCodeStreamClosed) // Set up a bunch of machinery to record the panic we saw // previously. var ( panMu sync.Mutex panicVal interface{} ) testHookOnPanicMu.Lock() testHookOnPanic = func(sc *serverConn, pv interface{}) bool { panMu.Lock() panicVal = pv panMu.Unlock() return true } testHookOnPanicMu.Unlock() // Now force the serve loop to end, via closing the connection. st.cc.Close() select { case <-st.sc.doneServing: // Loop has exited. panMu.Lock() got := panicVal panMu.Unlock() if got != nil { t.Errorf("Got panic: %v", got) } case <-time.After(5 * time.Second): t.Error("timeout") } }) } func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) } func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) } func testRejectTLS(t *testing.T, max uint16) { st := newServerTester(t, nil, func(c *tls.Config) { c.MaxVersion = max }) defer st.Close() gf := st.wantGoAway() if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want { t.Errorf("Got error code %v; want %v", got, want) } } func TestServer_Rejects_TLSBadCipher(t *testing.T) { st := newServerTester(t, nil, func(c *tls.Config) { // All TLS 1.3 ciphers are good. Test with TLS 1.2. c.MaxVersion = tls.VersionTLS12 // Only list bad ones: c.CipherSuites = []uint16{ tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, tls.TLS_RSA_WITH_AES_128_CBC_SHA, tls.TLS_RSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, } }) defer st.Close() gf := st.wantGoAway() if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want { t.Errorf("Got error code %v; want %v", got, want) } } func TestServer_Advertises_Common_Cipher(t *testing.T) { const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 st := newServerTester(t, nil, func(c *tls.Config) { // Have the client only support the one required by the spec. c.CipherSuites = []uint16{requiredSuite} }, func(ts *httptest.Server) { var srv *http.Server = ts.Config // Have the server configured with no specific cipher suites. // This tests that Go's defaults include the required one. srv.TLSConfig = nil }) defer st.Close() st.greet() } func (st *serverTester) onHeaderField(f hpack.HeaderField) { if f.Name == "date" { return } st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value}) } func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) { st.decodedHeaders = nil if _, err := st.hpackDec.Write(headerBlock); err != nil { st.t.Fatalf("hpack decoding error: %v", err) } if err := st.hpackDec.Close(); err != nil { st.t.Fatalf("hpack decoding error: %v", err) } return st.decodedHeaders } // testServerResponse sets up an idle HTTP/2 connection. The client function should // write a single request that must be handled by the handler. This waits up to 5s // for client to return, then up to an additional 2s for the handler to return. func testServerResponse(t testing.TB, handler func(http.ResponseWriter, *http.Request) error, client func(*serverTester), ) { errc := make(chan error, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if r.Body == nil { t.Fatal("nil Body") } errc <- handler(w, r) }) defer st.Close() donec := make(chan bool) go func() { defer close(donec) st.greet() client(st) }() select { case <-donec: case <-time.After(5 * time.Second): t.Fatal("timeout in client") } select { case err := <-errc: if err != nil { t.Fatalf("Error in handler: %v", err) } case <-time.After(2 * time.Second): t.Fatal("timeout in handler") } } // readBodyHandler returns an http Handler func that reads len(want) // bytes from r.Body and fails t if the contents read were not // the value of want. func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { buf := make([]byte, len(want)) _, err := io.ReadFull(r.Body, buf) if err != nil { t.Error(err) return } if string(buf) != want { t.Errorf("read %q; want %q", buf, want) } } } // TestServerWithCurl currently fails, hence the LenientCipherSuites test. See: // https://github.com/tatsuhiro-t/nghttp2/issues/140 & // http://sourceforge.net/p/curl/bugs/1472/ func TestServerWithCurl(t *testing.T) { testServerWithCurl(t, false) } func TestServerWithCurl_LenientCipherSuites(t *testing.T) { testServerWithCurl(t, true) } func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) { if runtime.GOOS != "linux" { t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway") } if testing.Short() { t.Skip("skipping curl test in short mode") } requireCurl(t) var gotConn int32 testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) } const msg = "Hello from curl!\n" ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Foo", "Bar") w.Header().Set("Client-Proto", r.Proto) io.WriteString(w, msg) })) ConfigureServer(ts.Config, &Server{ PermitProhibitedCipherSuites: permitProhibitedCipherSuites, }) ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config ts.StartTLS() defer ts.Close() t.Logf("Running test server for curl to hit at: %s", ts.URL) container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL) defer kill(container) resc := make(chan interface{}, 1) go func() { res, err := dockerLogs(container) if err != nil { resc <- err } else { resc <- res } }() select { case res := <-resc: if err, ok := res.(error); ok { t.Fatal(err) } body := string(res.([]byte)) // Search for both "key: value" and "key:value", since curl changed their format // Our Dockerfile contains the latest version (no space), but just in case people // didn't rebuild, check both. if !strings.Contains(body, "foo: Bar") && !strings.Contains(body, "foo:Bar") { t.Errorf("didn't see foo: Bar header") t.Logf("Got: %s", body) } if !strings.Contains(body, "client-proto: HTTP/2") && !strings.Contains(body, "client-proto:HTTP/2") { t.Errorf("didn't see client-proto: HTTP/2 header") t.Logf("Got: %s", res) } if !strings.Contains(string(res.([]byte)), msg) { t.Errorf("didn't see %q content", msg) t.Logf("Got: %s", res) } case <-time.After(3 * time.Second): t.Errorf("timeout waiting for curl") } if atomic.LoadInt32(&gotConn) == 0 { t.Error("never saw an http2 connection") } } var doh2load = flag.Bool("h2load", false, "Run h2load test") func TestServerWithH2Load(t *testing.T) { if !*doh2load { t.Skip("Skipping without --h2load flag.") } if runtime.GOOS != "linux" { t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway") } requireH2load(t) msg := strings.Repeat("Hello, h2load!\n", 5000) ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, msg) w.(http.Flusher).Flush() io.WriteString(w, msg) })) ts.StartTLS() defer ts.Close() cmd := exec.Command("docker", "run", "--net=host", "--entrypoint=/usr/local/bin/h2load", "gohttp2/curl", "-n100000", "-c100", "-m100", ts.URL) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { t.Fatal(err) } } // Issue 12843 func TestServerDoS_MaxHeaderListSize(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() // shake hands frameSize := defaultMaxReadFrameSize var advHeaderListSize *uint32 st.greetAndCheckSettings(func(s Setting) error { switch s.ID { case SettingMaxFrameSize: if s.Val < minMaxFrameSize { frameSize = minMaxFrameSize } else if s.Val > maxFrameSize { frameSize = maxFrameSize } else { frameSize = int(s.Val) } case SettingMaxHeaderListSize: advHeaderListSize = &s.Val } return nil }) if advHeaderListSize == nil { t.Errorf("server didn't advertise a max header list size") } else if *advHeaderListSize == 0 { t.Errorf("server advertised a max header list size of 0") } st.encodeHeaderField(":method", "GET") st.encodeHeaderField(":path", "/") st.encodeHeaderField(":scheme", "https") cookie := strings.Repeat("*", 4058) st.encodeHeaderField("cookie", cookie) st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.headerBuf.Bytes(), EndStream: true, EndHeaders: false, }) // Capture the short encoding of a duplicate ~4K cookie, now // that we've already sent it once. st.headerBuf.Reset() st.encodeHeaderField("cookie", cookie) // Now send 1MB of it. const size = 1 << 20 b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len()) for len(b) > 0 { chunk := b if len(chunk) > frameSize { chunk = chunk[:frameSize] } b = b[len(chunk):] st.fr.WriteContinuation(1, len(b) == 0, chunk) } h := st.wantHeaders() if !h.HeadersEnded() { t.Fatalf("Got HEADERS without END_HEADERS set: %v", h) } headers := st.decodeHeader(h.HeaderBlockFragment()) want := [][2]string{ {":status", "431"}, {"content-type", "text/html; charset=utf-8"}, {"content-length", "63"}, } if !reflect.DeepEqual(headers, want) { t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) } } func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Trailer", "test-trailer") return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if !hf.HeadersEnded() { t.Fatal("want END_HEADERS flag") } df := st.wantData() if len(df.data) != 0 { t.Fatal("did not want data") } if !df.StreamEnded() { t.Fatal("want END_STREAM flag") } }) } func TestCompressionErrorOnWrite(t *testing.T) { const maxStrLen = 8 << 10 var serverConfig *http.Server st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. }, func(ts *httptest.Server) { serverConfig = ts.Config serverConfig.MaxHeaderBytes = maxStrLen }) st.addLogFilter("connection error: COMPRESSION_ERROR") defer st.Close() st.greet() maxAllowed := st.sc.framer.maxHeaderStringLen() // Crank this up, now that we have a conn connected with the // hpack.Decoder's max string length set has been initialized // from the earlier low ~8K value. We want this higher so don't // hit the max header list size. We only want to test hitting // the max string size. serverConfig.MaxHeaderBytes = 1 << 20 // First a request with a header that's exactly the max allowed size // for the hpack compression. It's still too long for the header list // size, so we'll get the 431 error, but that keeps the compression // context still valid. hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed)) st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: hbf, EndStream: true, EndHeaders: true, }) h := st.wantHeaders() if !h.HeadersEnded() { t.Fatalf("Got HEADERS without END_HEADERS set: %v", h) } headers := st.decodeHeader(h.HeaderBlockFragment()) want := [][2]string{ {":status", "431"}, {"content-type", "text/html; charset=utf-8"}, {"content-length", "63"}, } if !reflect.DeepEqual(headers, want) { t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) } df := st.wantData() if !strings.Contains(string(df.Data()), "HTTP Error 431") { t.Errorf("Unexpected data body: %q", df.Data()) } if !df.StreamEnded() { t.Fatalf("expect data stream end") } // And now send one that's just one byte too big. hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1)) st.writeHeaders(HeadersFrameParam{ StreamID: 3, BlockFragment: hbf, EndStream: true, EndHeaders: true, }) ga := st.wantGoAway() if ga.ErrCode != ErrCodeCompression { t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode) } } func TestCompressionErrorOnClose(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. }) st.addLogFilter("connection error: COMPRESSION_ERROR") defer st.Close() st.greet() hbf := st.encodeHeader("foo", "bar") hbf = hbf[:len(hbf)-1] // truncate one byte from the end, so hpack.Decoder.Close fails. st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: hbf, EndStream: true, EndHeaders: true, }) ga := st.wantGoAway() if ga.ErrCode != ErrCodeCompression { t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode) } } // test that a server handler can read trailers from a client func TestServerReadsTrailers(t *testing.T) { const testBody = "some test body" writeReq := func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"), EndStream: false, EndHeaders: true, }) st.writeData(1, false, []byte(testBody)) st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeaderRaw( "foo", "foov", "bar", "barv", "baz", "bazv", "surprise", "wasn't declared; shouldn't show up", ), EndStream: true, EndHeaders: true, }) } checkReq := func(r *http.Request) { wantTrailer := http.Header{ "Foo": nil, "Bar": nil, "Baz": nil, } if !reflect.DeepEqual(r.Trailer, wantTrailer) { t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer) } slurp, err := ioutil.ReadAll(r.Body) if string(slurp) != testBody { t.Errorf("read body %q; want %q", slurp, testBody) } if err != nil { t.Fatalf("Body slurp: %v", err) } wantTrailerAfter := http.Header{ "Foo": {"foov"}, "Bar": {"barv"}, "Baz": {"bazv"}, } if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) { t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter) } } testServerRequest(t, writeReq, checkReq) } // test that a server handler can send trailers func TestServerWritesTrailers_WithFlush(t *testing.T) { testServerWritesTrailers(t, true) } func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) } func testServerWritesTrailers(t *testing.T, withFlush bool) { // See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") w.Header().Add("Trailer", "Server-Trailer-C") w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered // Regular headers: w.Header().Set("Foo", "Bar") w.Header().Set("Content-Length", "5") // len("Hello") io.WriteString(w, "Hello") if withFlush { w.(http.Flusher).Flush() } w.Header().Set("Server-Trailer-A", "valuea") w.Header().Set("Server-Trailer-C", "valuec") // skipping B // After a flush, random keys like Server-Surprise shouldn't show up: w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!") // But we do permit promoting keys to trailers after a // flush if they start with the magic // otherwise-invalid "Trailer:" prefix: w.Header().Set("Trailer:Post-Header-Trailer", "hi1") w.Header().Set("Trailer:post-header-trailer2", "hi2") w.Header().Set("Trailer:Range", "invalid") w.Header().Set("Trailer:Foo\x01Bogus", "invalid") w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2") w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2") w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2") return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if hf.StreamEnded() { t.Fatal("response HEADERS had END_STREAM") } if !hf.HeadersEnded() { t.Fatal("response HEADERS didn't have END_HEADERS") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"foo", "Bar"}, {"trailer", "Server-Trailer-A, Server-Trailer-B"}, {"trailer", "Server-Trailer-C"}, {"trailer", "Transfer-Encoding, Content-Length, Trailer"}, {"content-type", "text/plain; charset=utf-8"}, {"content-length", "5"}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth) } df := st.wantData() if string(df.Data()) != "Hello" { t.Fatalf("Client read %q; want Hello", df.Data()) } if df.StreamEnded() { t.Fatalf("data frame had STREAM_ENDED") } tf := st.wantHeaders() // for the trailers if !tf.StreamEnded() { t.Fatalf("trailers HEADERS lacked END_STREAM") } if !tf.HeadersEnded() { t.Fatalf("trailers HEADERS lacked END_HEADERS") } wanth = [][2]string{ {"post-header-trailer", "hi1"}, {"post-header-trailer2", "hi2"}, {"server-trailer-a", "valuea"}, {"server-trailer-c", "valuec"}, } goth = st.decodeHeader(tf.HeaderBlockFragment()) if !reflect.DeepEqual(goth, wanth) { t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth) } }) } // validate transmitted header field names & values // golang.org/issue/14048 func TestServerDoesntWriteInvalidHeaders(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Add("OK1", "x") w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key w.Header().Add("Bad1\x00", "x") // null in key w.Header().Add("Bad2", "x\x00y") // null in value return nil }, func(st *serverTester) { getSlash(st) hf := st.wantHeaders() if !hf.StreamEnded() { t.Error("response HEADERS lacked END_STREAM") } if !hf.HeadersEnded() { t.Fatal("response HEADERS didn't have END_HEADERS") } goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"ok1", "x"}, {"content-length", "0"}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth) } }) } func BenchmarkServerGets(b *testing.B) { defer disableGoroutineTracking()() b.ReportAllocs() const msg = "Hello, world" st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, msg) }) defer st.Close() st.greet() // Give the server quota to reply. (plus it has the 64KB) if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { b.Fatal(err) } for i := 0; i < b.N; i++ { id := 1 + uint32(i)*2 st.writeHeaders(HeadersFrameParam{ StreamID: id, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, }) st.wantHeaders() df := st.wantData() if !df.StreamEnded() { b.Fatalf("DATA didn't have END_STREAM; got %v", df) } } } func BenchmarkServerPosts(b *testing.B) { defer disableGoroutineTracking()() b.ReportAllocs() const msg = "Hello, world" st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { // Consume the (empty) body from th peer before replying, otherwise // the server will sometimes (depending on scheduling) send the peer a // a RST_STREAM with the CANCEL error code. if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil { b.Errorf("Copy error; got %v, %v; want 0, nil", n, err) } io.WriteString(w, msg) }) defer st.Close() st.greet() // Give the server quota to reply. (plus it has the 64KB) if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { b.Fatal(err) } for i := 0; i < b.N; i++ { id := 1 + uint32(i)*2 st.writeHeaders(HeadersFrameParam{ StreamID: id, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, EndHeaders: true, }) st.writeData(id, true, nil) st.wantHeaders() df := st.wantData() if !df.StreamEnded() { b.Fatalf("DATA didn't have END_STREAM; got %v", df) } } } // Send a stream of messages from server to client in separate data frames. // Brings up performance issues seen in long streams. // Created to show problem in go issue #18502 func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) { benchmarkServerToClientStream(b) } // Justification for Change-Id: Iad93420ef6c3918f54249d867098f1dadfa324d8 // Expect to see memory/alloc reduction by opting in to Frame reuse with the Framer. func BenchmarkServerToClientStreamReuseFrames(b *testing.B) { benchmarkServerToClientStream(b, optFramerReuseFrames) } func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { defer disableGoroutineTracking()() b.ReportAllocs() const msgLen = 1 // default window size const windowSize = 1<<16 - 1 // next message to send from the server and for the client to expect nextMsg := func(i int) []byte { msg := make([]byte, msgLen) msg[0] = byte(i) if len(msg) != msgLen { panic("invalid test setup msg length") } return msg } st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { // Consume the (empty) body from th peer before replying, otherwise // the server will sometimes (depending on scheduling) send the peer a // a RST_STREAM with the CANCEL error code. if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil { b.Errorf("Copy error; got %v, %v; want 0, nil", n, err) } for i := 0; i < b.N; i += 1 { w.Write(nextMsg(i)) w.(http.Flusher).Flush() } }, newServerOpts...) defer st.Close() st.greet() const id = uint32(1) st.writeHeaders(HeadersFrameParam{ StreamID: id, BlockFragment: st.encodeHeader(":method", "POST"), EndStream: false, EndHeaders: true, }) st.writeData(id, true, nil) st.wantHeaders() var pendingWindowUpdate = uint32(0) for i := 0; i < b.N; i += 1 { expected := nextMsg(i) df := st.wantData() if bytes.Compare(expected, df.data) != 0 { b.Fatalf("Bad message received; want %v; got %v", expected, df.data) } // try to send infrequent but large window updates so they don't overwhelm the test pendingWindowUpdate += uint32(len(df.data)) if pendingWindowUpdate >= windowSize/2 { if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil { b.Fatal(err) } if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil { b.Fatal(err) } pendingWindowUpdate = 0 } } df := st.wantData() if !df.StreamEnded() { b.Fatalf("DATA didn't have END_STREAM; got %v", df) } } // go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53 // Verify we don't hang. func TestIssue53(t *testing.T) { const data = "PRI * HTTP/2.0\r\n\r\nSM" + "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad" s := &http.Server{ ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags), Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }), } s2 := &Server{ MaxReadFrameSize: 1 << 16, PermitProhibitedCipherSuites: true, } c := &issue53Conn{[]byte(data), false, false} s2.ServeConn(c, &ServeConnOpts{BaseConfig: s}) if !c.closed { t.Fatal("connection is not closed") } } type issue53Conn struct { data []byte closed bool written bool } func (c *issue53Conn) Read(b []byte) (n int, err error) { if len(c.data) == 0 { return 0, io.EOF } n = copy(b, c.data) c.data = c.data[n:] return } func (c *issue53Conn) Write(b []byte) (n int, err error) { c.written = true return len(b), nil } func (c *issue53Conn) Close() error { c.closed = true return nil } func (c *issue53Conn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706} } func (c *issue53Conn) RemoteAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706} } func (c *issue53Conn) SetDeadline(t time.Time) error { return nil } func (c *issue53Conn) SetReadDeadline(t time.Time) error { return nil } func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil } // golang.org/issue/12895 func TestConfigureServer(t *testing.T) { tests := []struct { name string tlsConfig *tls.Config wantErr string }{ { name: "empty server", }, { name: "just the required cipher suite", tlsConfig: &tls.Config{ CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, }, }, { name: "just the alternative required cipher suite", tlsConfig: &tls.Config{ CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, }, }, { name: "missing required cipher suite", tlsConfig: &tls.Config{ CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384}, }, wantErr: "is missing an HTTP/2-required AES_128_GCM_SHA256 cipher.", }, { name: "required after bad", tlsConfig: &tls.Config{ CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, }, wantErr: "contains an HTTP/2-approved cipher suite (0xc02f), but it comes after", }, { name: "bad after required", tlsConfig: &tls.Config{ CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA}, }, }, } for _, tt := range tests { srv := &http.Server{TLSConfig: tt.tlsConfig} err := ConfigureServer(srv, nil) if (err != nil) != (tt.wantErr != "") { if tt.wantErr != "" { t.Errorf("%s: success, but want error", tt.name) } else { t.Errorf("%s: unexpected error: %v", tt.name, err) } } if err != nil && tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) { t.Errorf("%s: err = %v; want substring %q", tt.name, err, tt.wantErr) } if err == nil && !srv.TLSConfig.PreferServerCipherSuites { t.Errorf("%s: PreferServerCipherSuite is false; want true", tt.name) } } } func TestServerRejectHeadWithBody(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. }) defer st.Close() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "HEAD"), EndStream: false, // what we're testing, a bogus HEAD request with body EndHeaders: true, }) st.wantRSTStream(1, ErrCodeProtocol) } func TestServerNoAutoContentLengthOnHead(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. (or smaller than one frame) }) defer st.Close() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader(":method", "HEAD"), EndStream: true, EndHeaders: true, }) h := st.wantHeaders() headers := st.decodeHeader(h.HeaderBlockFragment()) want := [][2]string{ {":status", "200"}, } if !reflect.DeepEqual(headers, want) { t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) } } // golang.org/issue/13495 func TestServerNoDuplicateContentType(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "hi") }) defer st.Close() st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, }) h := st.wantHeaders() headers := st.decodeHeader(h.HeaderBlockFragment()) want := [][2]string{ {":status", "200"}, {"content-type", ""}, {"content-length", "41"}, } if !reflect.DeepEqual(headers, want) { t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) } } func disableGoroutineTracking() (restore func()) { old := DebugGoroutines DebugGoroutines = false return func() { DebugGoroutines = old } } func BenchmarkServer_GetRequest(b *testing.B) { defer disableGoroutineTracking()() b.ReportAllocs() const msg = "Hello, world." st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { n, err := io.Copy(ioutil.Discard, r.Body) if err != nil || n > 0 { b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err) } io.WriteString(w, msg) }) defer st.Close() st.greet() // Give the server quota to reply. (plus it has the 64KB) if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { b.Fatal(err) } hbf := st.encodeHeader(":method", "GET") for i := 0; i < b.N; i++ { streamID := uint32(1 + 2*i) st.writeHeaders(HeadersFrameParam{ StreamID: streamID, BlockFragment: hbf, EndStream: true, EndHeaders: true, }) st.wantHeaders() st.wantData() } } func BenchmarkServer_PostRequest(b *testing.B) { defer disableGoroutineTracking()() b.ReportAllocs() const msg = "Hello, world." st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { n, err := io.Copy(ioutil.Discard, r.Body) if err != nil || n > 0 { b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err) } io.WriteString(w, msg) }) defer st.Close() st.greet() // Give the server quota to reply. (plus it has the 64KB) if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { b.Fatal(err) } hbf := st.encodeHeader(":method", "POST") for i := 0; i < b.N; i++ { streamID := uint32(1 + 2*i) st.writeHeaders(HeadersFrameParam{ StreamID: streamID, BlockFragment: hbf, EndStream: false, EndHeaders: true, }) st.writeData(streamID, true, nil) st.wantHeaders() st.wantData() } } type connStateConn struct { net.Conn cs tls.ConnectionState } func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs } // golang.org/issue/12737 -- handle any net.Conn, not just // *tls.Conn. func TestServerHandleCustomConn(t *testing.T) { var s Server c1, c2 := net.Pipe() clientDone := make(chan struct{}) handlerDone := make(chan struct{}) var req *http.Request go func() { defer close(clientDone) defer c2.Close() fr := NewFramer(c2, c2) io.WriteString(c2, ClientPreface) fr.WriteSettings() fr.WriteSettingsAck() f, err := fr.ReadFrame() if err != nil { t.Error(err) return } if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() { t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f)) return } f, err = fr.ReadFrame() if err != nil { t.Error(err) return } if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() { t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f)) return } var henc hpackEncoder fr.WriteHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"), EndStream: true, EndHeaders: true, }) go io.Copy(ioutil.Discard, c2) <-handlerDone }() const testString = "my custom ConnectionState" fakeConnState := tls.ConnectionState{ ServerName: testString, Version: tls.VersionTLS12, CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{ BaseConfig: &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer close(handlerDone) req = r }), }}) select { case <-clientDone: case <-time.After(5 * time.Second): t.Fatal("timeout waiting for handler") } if req.TLS == nil { t.Fatalf("Request.TLS is nil. Got: %#v", req) } if req.TLS.ServerName != testString { t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString) } } // golang.org/issue/14214 func TestServer_Rejects_ConnHeaders(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("should not get to Handler") }) defer st.Close() st.greet() st.bodylessReq1("connection", "foo") hf := st.wantHeaders() goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "400"}, {"content-type", "text/plain; charset=utf-8"}, {"x-content-type-options", "nosniff"}, {"content-length", "51"}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } } type hpackEncoder struct { enc *hpack.Encoder buf bytes.Buffer } func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte { if len(headers)%2 == 1 { panic("odd number of kv args") } he.buf.Reset() if he.enc == nil { he.enc = hpack.NewEncoder(&he.buf) } for len(headers) > 0 { k, v := headers[0], headers[1] err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v}) if err != nil { t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) } headers = headers[2:] } return he.buf.Bytes() } func TestCheckValidHTTP2Request(t *testing.T) { tests := []struct { h http.Header want error }{ { h: http.Header{"Te": {"trailers"}}, want: nil, }, { h: http.Header{"Te": {"trailers", "bogus"}}, want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`), }, { h: http.Header{"Foo": {""}}, want: nil, }, { h: http.Header{"Connection": {""}}, want: errors.New(`request header "Connection" is not valid in HTTP/2`), }, { h: http.Header{"Proxy-Connection": {""}}, want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`), }, { h: http.Header{"Keep-Alive": {""}}, want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`), }, { h: http.Header{"Upgrade": {""}}, want: errors.New(`request header "Upgrade" is not valid in HTTP/2`), }, } for i, tt := range tests { got := checkValidHTTP2RequestHeaders(tt.h) if !equalError(got, tt.want) { t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want) } } } // golang.org/issue/14030 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) { const msg = "Hello" const msg2 = "World" doRead := make(chan bool, 1) defer close(doRead) // fallback cleanup st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, msg) w.(http.Flusher).Flush() // Do a read, which might force a 100-continue status to be sent. <-doRead r.Body.Read(make([]byte, 10)) io.WriteString(w, msg2) }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20)) req.Header.Set("Expect", "100-continue") res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } defer res.Body.Close() buf := make([]byte, len(msg)) if _, err := io.ReadFull(res.Body, buf); err != nil { t.Fatal(err) } if string(buf) != msg { t.Fatalf("msg = %q; want %q", buf, msg) } doRead <- true if _, err := io.ReadFull(res.Body, buf); err != nil { t.Fatal(err) } if string(buf) != msg2 { t.Fatalf("second msg = %q; want %q", buf, msg2) } } type funcReader func([]byte) (n int, err error) func (f funcReader) Read(p []byte) (n int, err error) { return f(p) } // golang.org/issue/16481 -- return flow control when streams close with unread data. // (The Server version of the bug. See also TestUnreadFlowControlReturned_Transport) func TestUnreadFlowControlReturned_Server(t *testing.T) { unblock := make(chan bool, 1) defer close(unblock) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // Don't read the 16KB request body. Wait until the client's // done sending it and then return. This should cause the Server // to then return those 16KB of flow control to the client. <-unblock }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() // This previously hung on the 4th iteration. for i := 0; i < 6; i++ { body := io.MultiReader( io.LimitReader(neverEnding('A'), 16<<10), funcReader(func([]byte) (n int, err error) { unblock <- true return 0, io.EOF }), ) req, _ := http.NewRequest("POST", st.ts.URL, body) res, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } res.Body.Close() } } func TestServerIdleTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(h2s *Server) { h2s.IdleTimeout = 500 * time.Millisecond }) defer st.Close() st.greet() ga := st.wantGoAway() if ga.ErrCode != ErrCodeNo { t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) } } func TestServerIdleTimeout_AfterRequest(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } const timeout = 250 * time.Millisecond st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { time.Sleep(timeout * 2) }, func(h2s *Server) { h2s.IdleTimeout = timeout }) defer st.Close() st.greet() // Send a request which takes twice the timeout. Verifies the // idle timeout doesn't fire while we're in a request: st.bodylessReq1() st.wantHeaders() // But the idle timeout should be rearmed after the request // is done: ga := st.wantGoAway() if ga.ErrCode != ErrCodeNo { t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) } } // grpc-go closes the Request.Body currently with a Read. // Verify that it doesn't race. // See https://github.com/grpc/grpc-go/pull/938 func TestRequestBodyReadCloseRace(t *testing.T) { for i := 0; i < 100; i++ { body := &requestBody{ pipe: &pipe{ b: new(bytes.Buffer), }, } body.pipe.CloseWithError(io.EOF) done := make(chan bool, 1) buf := make([]byte, 10) go func() { time.Sleep(1 * time.Millisecond) body.Close() done <- true }() body.Read(buf) <-done } } func TestIssue20704Race(t *testing.T) { if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" { t.Skip("skipping in short mode") } const ( itemSize = 1 << 10 itemCount = 100 ) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { for i := 0; i < itemCount; i++ { _, err := w.Write(make([]byte, itemSize)) if err != nil { return } } }, optOnlyServer) defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() cl := &http.Client{Transport: tr} for i := 0; i < 1000; i++ { resp, err := cl.Get(st.ts.URL) if err != nil { t.Fatal(err) } // Force a RST stream to the server by closing without // reading the body: resp.Body.Close() } } func TestServer_Rejects_TooSmall(t *testing.T) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { ioutil.ReadAll(r.Body) return nil }, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers BlockFragment: st.encodeHeader( ":method", "POST", "content-length", "4", ), EndStream: false, // to say DATA frames are coming EndHeaders: true, }) st.writeData(1, true, []byte("12345")) st.wantRSTStream(1, ErrCodeProtocol) }) } // Tests that a handler setting "Connection: close" results in a GOAWAY being sent, // and the connection still completing. func TestServerHandlerConnectionClose(t *testing.T) { unblockHandler := make(chan bool, 1) defer close(unblockHandler) // backup; in case of errors testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Connection", "close") w.Header().Set("Foo", "bar") w.(http.Flusher).Flush() <-unblockHandler return nil }, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: true, EndHeaders: true, }) var sawGoAway bool var sawRes bool for { f, err := st.readFrame() if err == io.EOF { break } if err != nil { t.Fatal(err) } switch f := f.(type) { case *GoAwayFrame: sawGoAway = true unblockHandler <- true if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo { t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f)) } case *HeadersFrame: goth := st.decodeHeader(f.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"foo", "bar"}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("got headers %v; want %v", goth, wanth) } sawRes = true case *DataFrame: if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 { t.Errorf("unexpected DATA frame: %v", summarizeFrame(f)) } default: t.Logf("unexpected frame: %v", summarizeFrame(f)) } } if !sawGoAway { t.Errorf("didn't see GOAWAY") } if !sawRes { t.Errorf("didn't see response") } }) } func TestServer_Headers_HalfCloseRemote(t *testing.T) { var st *serverTester writeData := make(chan bool) writeHeaders := make(chan bool) leaveHandler := make(chan bool) st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if st.stream(1) == nil { t.Errorf("nil stream 1 in handler") } if got, want := st.streamState(1), stateOpen; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } writeData <- true if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF { t.Errorf("body read = %d, %v; want 0, EOF", n, err) } if got, want := st.streamState(1), stateHalfClosedRemote; got != want { t.Errorf("in handler, state is %v; want %v", got, want) } writeHeaders <- true <-leaveHandler }) st.greet() st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: false, // keep it open EndHeaders: true, }) <-writeData st.writeData(1, true, nil) <-writeHeaders st.writeHeaders(HeadersFrameParam{ StreamID: 1, BlockFragment: st.encodeHeader(), EndStream: false, // keep it open EndHeaders: true, }) defer close(leaveHandler) st.wantRSTStream(1, ErrCodeStreamClosed) } func TestServerGracefulShutdown(t *testing.T) { var st *serverTester handlerDone := make(chan struct{}) st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { defer close(handlerDone) go st.ts.Config.Shutdown(context.Background()) ga := st.wantGoAway() if ga.ErrCode != ErrCodeNo { t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) } if ga.LastStreamID != 1 { t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID) } w.Header().Set("x-foo", "bar") }) defer st.Close() st.greet() st.bodylessReq1() select { case <-handlerDone: case <-time.After(5 * time.Second): t.Fatalf("server did not shutdown?") } hf := st.wantHeaders() goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ {":status", "200"}, {"x-foo", "bar"}, {"content-length", "0"}, } if !reflect.DeepEqual(goth, wanth) { t.Errorf("Got headers %v; want %v", goth, wanth) } n, err := st.cc.Read([]byte{0}) if n != 0 || err == nil { t.Errorf("Read = %v, %v; want 0, non-nil", n, err) } }