Add unit tests

This commit is contained in:
Mayuresh Gaitonde
2020-12-17 17:25:53 -08:00
parent 3702339f44
commit 6be4d69d02
705 changed files with 120529 additions and 150051 deletions

View File

@@ -24,9 +24,10 @@ import (
)
const (
// bdpLimit is the maximum value the flow control windows
// will be increased to.
bdpLimit = (1 << 20) * 4
// bdpLimit is the maximum value the flow control windows will be increased
// to. TCP typically limits this to 4MB, but some systems go up to 16MB.
// Since this is only a limit, it is safe to make it optimistic.
bdpLimit = (1 << 20) * 16
// alpha is a constant factor used to keep a moving average
// of RTTs.
alpha = 0.9

View File

@@ -23,6 +23,7 @@ import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
@@ -84,24 +85,40 @@ func (il *itemList) isEmpty() bool {
// the control buffer of transport. They represent different aspects of
// control tasks, e.g., flow control, settings, streaming resetting, etc.
// maxQueuedTransportResponseFrames is the most queued "transport response"
// frames we will buffer before preventing new reads from occurring on the
// transport. These are control frames sent in response to client requests,
// such as RST_STREAM due to bad headers or settings acks.
const maxQueuedTransportResponseFrames = 50
// cbItem is the interface every item handed to the control buffer (via put or
// executeAndPut) must satisfy.
type cbItem interface {
// isTransportResponseFrame reports whether the item counts as a "transport
// response" frame. The control buffer tallies queued items for which this
// returns true and compares the tally with maxQueuedTransportResponseFrames.
isTransportResponseFrame() bool
}
// registerStream is used to register an incoming stream with loopy writer.
type registerStream struct {
streamID uint32 // ID of the stream being registered.
wq *writeQuota // presumably the write quota of that stream -- confirm against the loopy writer.
}
// isTransportResponseFrame always reports false, so a queued registerStream is
// never counted toward maxQueuedTransportResponseFrames.
func (*registerStream) isTransportResponseFrame() bool { return false }
// headerFrame is also used to register stream on the client-side.
type headerFrame struct {
streamID uint32
hf []hpack.HeaderField
endStream bool // Valid on server side.
initStream func(uint32) (bool, error) // Used only on the client side.
endStream bool // Valid on server side.
initStream func(uint32) error // Used only on the client side.
onWrite func()
wq *writeQuota // write quota for the stream created.
cleanup *cleanupStream // Valid on the server side.
onOrphaned func(error) // Valid on client-side
}
// isTransportResponseFrame reports whether this header frame carries a cleanup
// item whose rst flag is set, i.e. whether writing it results in a RST_STREAM.
// A header frame without a cleanup item is never a transport response frame.
func (h *headerFrame) isTransportResponseFrame() bool {
	if h.cleanup == nil {
		return false
	}
	return h.cleanup.rst
}
type cleanupStream struct {
streamID uint32
rst bool
@@ -109,6 +126,8 @@ type cleanupStream struct {
onWrite func()
}
// isTransportResponseFrame returns the item's rst flag: a cleanupStream counts
// as a transport response frame only when rst is set.
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
type dataFrame struct {
streamID uint32
endStream bool
@@ -119,27 +138,41 @@ type dataFrame struct {
onEachWrite func()
}
// isTransportResponseFrame always reports false, so a queued dataFrame is never
// counted toward maxQueuedTransportResponseFrames.
func (*dataFrame) isTransportResponseFrame() bool { return false }
// incomingWindowUpdate is a control item carrying a stream ID and a window
// increment.
// NOTE(review): by its name this represents a window update received from the
// peer -- confirm against the handler that consumes it.
type incomingWindowUpdate struct {
streamID uint32
increment uint32
}
// isTransportResponseFrame always reports false, so this item is never counted
// toward maxQueuedTransportResponseFrames.
func (*incomingWindowUpdate) isTransportResponseFrame() bool { return false }
// outgoingWindowUpdate is a control item carrying a stream ID and a window
// increment.
// NOTE(review): by its name this represents a window update to be sent to the
// peer -- confirm against the handler that consumes it.
type outgoingWindowUpdate struct {
streamID uint32
increment uint32
}
// isTransportResponseFrame always reports false, so this item is never counted
// toward maxQueuedTransportResponseFrames.
func (*outgoingWindowUpdate) isTransportResponseFrame() bool {
return false // window updates are throttled by thresholds
}
// incomingSettings is a control item carrying a list of HTTP/2 settings (ss).
type incomingSettings struct {
ss []http2.Setting
}
// isTransportResponseFrame always reports true, so every queued
// incomingSettings counts toward maxQueuedTransportResponseFrames.
func (*incomingSettings) isTransportResponseFrame() bool { return true } // Results in a settings ACK
// outgoingSettings is a control item carrying a list of HTTP/2 settings (ss).
type outgoingSettings struct {
ss []http2.Setting
}
// isTransportResponseFrame always reports false, so this item is never counted
// toward maxQueuedTransportResponseFrames.
func (*outgoingSettings) isTransportResponseFrame() bool { return false }
// incomingGoAway is a field-less control item; the client transport's
// GracefulClose queues one when it starts draining with streams still active.
// NOTE(review): by its name it is presumably also queued when a GOAWAY frame
// arrives from the peer -- confirm against the reader.
type incomingGoAway struct {
}
// isTransportResponseFrame always reports false, so this item is never counted
// toward maxQueuedTransportResponseFrames.
func (*incomingGoAway) isTransportResponseFrame() bool { return false }
type goAway struct {
code http2.ErrCode
debugData []byte
@@ -147,15 +180,21 @@ type goAway struct {
closeConn bool
}
// isTransportResponseFrame always reports false, so a queued goAway is never
// counted toward maxQueuedTransportResponseFrames.
func (*goAway) isTransportResponseFrame() bool { return false }
// ping is the control item for a ping.
// NOTE(review): ack is presumably the ACK flag of the HTTP/2 PING frame and
// data its 8-byte payload -- confirm against the ping handler.
type ping struct {
ack bool
data [8]byte
}
// isTransportResponseFrame always reports true, so every queued ping counts
// toward maxQueuedTransportResponseFrames.
func (*ping) isTransportResponseFrame() bool { return true }
// outFlowControlSizeRequest is a control item holding a channel of uint32.
// NOTE(review): resp is presumably where the requested flow-control size is
// delivered to the requester -- confirm against the handler.
type outFlowControlSizeRequest struct {
resp chan uint32
}
// isTransportResponseFrame always reports false, so this item is never counted
// toward maxQueuedTransportResponseFrames.
func (*outFlowControlSizeRequest) isTransportResponseFrame() bool { return false }
type outStreamState int
const (
@@ -238,6 +277,14 @@ type controlBuffer struct {
consumerWaiting bool
list *itemList
err error
// transportResponseFrames counts the number of queued items that represent
// the response of an action initiated by the peer. trfChan is created
// when transportResponseFrames >= maxQueuedTransportResponseFrames and is
// closed and nilled when transportResponseFrames drops below the
// threshold. Both fields are protected by mu.
transportResponseFrames int
trfChan atomic.Value // *chan struct{}
}
func newControlBuffer(done <-chan struct{}) *controlBuffer {
@@ -248,12 +295,24 @@ func newControlBuffer(done <-chan struct{}) *controlBuffer {
}
}
func (c *controlBuffer) put(it interface{}) error {
// throttle blocks the caller while too many transport response frames (for
// example incomingSettings, or cleanupStreams that reset the stream) are
// queued in the control buffer. It returns immediately when no throttling
// channel is currently published in trfChan; otherwise it waits until that
// channel is closed or the buffer's done channel fires, whichever is first.
func (c *controlBuffer) throttle() {
	gate, _ := c.trfChan.Load().(*chan struct{})
	if gate == nil {
		// Nothing stored yet, or the stored pointer was cleared: not throttled.
		return
	}
	select {
	case <-*gate:
	case <-c.done:
	}
}
// put queues it on the control buffer. It delegates to executeAndPut with a
// nil pre-check function and reports only the resulting error.
func (c *controlBuffer) put(it cbItem) error {
	if _, err := c.executeAndPut(nil, it); err != nil {
		return err
	}
	return nil
}
func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) {
func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (bool, error) {
var wakeUp bool
c.mu.Lock()
if c.err != nil {
@@ -271,6 +330,15 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{
c.consumerWaiting = false
}
c.list.enqueue(it)
if it.isTransportResponseFrame() {
c.transportResponseFrames++
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are adding the frame that puts us over the threshold; create
// a throttling channel.
ch := make(chan struct{})
c.trfChan.Store(&ch)
}
}
c.mu.Unlock()
if wakeUp {
select {
@@ -304,7 +372,17 @@ func (c *controlBuffer) get(block bool) (interface{}, error) {
return nil, c.err
}
if !c.list.isEmpty() {
h := c.list.dequeue()
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Load().(*chan struct{})
close(*ch)
c.trfChan.Store((*chan struct{})(nil))
}
c.transportResponseFrames--
}
c.mu.Unlock()
return h, nil
}
@@ -559,21 +637,17 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
func (l *loopyWriter) originateStream(str *outStream) error {
hdr := str.itl.dequeue().(*headerFrame)
sendPing, err := hdr.initStream(str.id)
if err != nil {
if err := hdr.initStream(str.id); err != nil {
if err == ErrConnClosing {
return err
}
// Other errors(errStreamDrain) need not close transport.
return nil
}
if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
return err
}
l.estdStreams[str.id] = str
if sendPing {
return l.pingHandler(&ping{data: [8]byte{}})
}
return nil
}

View File

@@ -149,6 +149,7 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
n = uint32(math.MaxInt32)
}
f.mu.Lock()
defer f.mu.Unlock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
@@ -169,10 +170,8 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
// is padded; We will fall back on the current available window (at least a 1/4th of the limit).
f.delta = n
}
f.mu.Unlock()
return f.delta
}
f.mu.Unlock()
return 0
}

View File

@@ -1,52 +0,0 @@
// +build go1.6,!go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"net"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"golang.org/x/net/context"
)
// dialContext dials address on the named network. The context's Done channel
// is installed as the dialer's Cancel channel.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
	dialer := net.Dialer{Cancel: ctx.Done()}
	return dialer.Dial(network, address)
}
// ContextErr maps an error produced by the context package onto a status
// error: DeadlineExceeded and Canceled become the status codes of the same
// name (keeping the original message); any other value is reported as an
// Internal error.
func ContextErr(err error) error {
	if err == context.DeadlineExceeded {
		return status.Error(codes.DeadlineExceeded, err.Error())
	}
	if err == context.Canceled {
		return status.Error(codes.Canceled, err.Error())
	}
	return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err)
}
// contextFromRequest returns a background context.
// The request r is not consulted.
// NOTE(review): presumably because this file is only built for Go 1.6 (see the
// build constraint at the top of the file), where the request cannot supply a
// context -- confirm.
func contextFromRequest(r *http.Request) context.Context {
return context.Background()
}

View File

@@ -1,53 +0,0 @@
// +build go1.7
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"context"
"net"
"net/http"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
netctx "golang.org/x/net/context"
)
// dialContext dials address on the named network with a zero-value
// net.Dialer, passing ctx through to DialContext.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
	var dialer net.Dialer
	return dialer.DialContext(ctx, network, address)
}
// ContextErr maps an error produced by either context package (the standard
// library one or golang.org/x/net/context) onto a status error:
// DeadlineExceeded and Canceled become the status codes of the same name
// (keeping the original message); any other value is reported as an Internal
// error.
func ContextErr(err error) error {
	if err == context.DeadlineExceeded || err == netctx.DeadlineExceeded {
		return status.Error(codes.DeadlineExceeded, err.Error())
	}
	if err == context.Canceled || err == netctx.Canceled {
		return status.Error(codes.Canceled, err.Error())
	}
	return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err)
}
// contextFromRequest returns a context from the HTTP Request.
// It is a thin wrapper: the result is exactly what r.Context() returns.
func contextFromRequest(r *http.Request) context.Context {
return r.Context()
}

View File

@@ -24,6 +24,8 @@
package transport
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -34,7 +36,6 @@ import (
"time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -63,9 +64,6 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats sta
if _, ok := w.(http.Flusher); !ok {
return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
}
if _, ok := w.(http.CloseNotifier); !ok {
return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier")
}
st := &serverHandlerTransport{
rw: w,
@@ -176,17 +174,11 @@ func (a strAddr) String() string { return string(a) }
// do runs fn in the ServeHTTP goroutine.
func (ht *serverHandlerTransport) do(fn func()) error {
// Avoid a panic writing to closed channel. Imperfect but maybe good enough.
select {
case <-ht.closedCh:
return ErrConnClosing
default:
select {
case ht.writes <- fn:
return nil
case <-ht.closedCh:
return ErrConnClosing
}
case ht.writes <- fn:
return nil
}
}
@@ -235,9 +227,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
if err == nil { // transport has not been closed
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
}
close(ht.writes)
}
ht.Close()
return err
@@ -298,7 +291,9 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
if err == nil {
if ht.stats != nil {
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{
Header: md.Copy(),
})
}
}
return err
@@ -307,7 +302,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
// With this transport type there will be exactly 1 stream: this HTTP request.
ctx := contextFromRequest(ht.req)
ctx := ht.req.Context()
var cancel context.CancelFunc
if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
@@ -315,19 +310,13 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ctx, cancel = context.WithCancel(ctx)
}
// requestOver is closed when either the request's context is done
// or the status has been written via WriteStatus.
// requestOver is closed when the status has been written via WriteStatus.
requestOver := make(chan struct{})
// clientGone receives a single value if peer is gone, either
// because the underlying connection is dead or because the
// peer sends an http2 RST_STREAM.
clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
go func() {
select {
case <-requestOver:
case <-ht.closedCh:
case <-clientGone:
case <-ht.req.Context().Done():
}
cancel()
ht.Close()
@@ -349,7 +338,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
@@ -363,7 +352,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ht.stats.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
windowHandler: func(int) {},
}
@@ -377,7 +366,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf)
if n > 0 {
s.buf.put(recvMsg{data: buf[:n:n]})
s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
buf = buf[n:]
}
if err != nil {
@@ -407,10 +396,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
func (ht *serverHandlerTransport) runStream() {
for {
select {
case fn, ok := <-ht.writes:
if !ok {
return
}
case fn := <-ht.writes:
fn()
case <-ht.closedCh:
return

View File

@@ -19,6 +19,8 @@
package transport
import (
"context"
"fmt"
"io"
"math"
"net"
@@ -28,13 +30,14 @@ import (
"sync/atomic"
"time"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@@ -42,8 +45,14 @@ import (
"google.golang.org/grpc/status"
)
// clientConnectionCounter counts the number of connections a client has
// initiated (equal to the number of http2Clients created). Must be accessed
// atomically.
var clientConnectionCounter uint64
// http2Client implements the ClientTransport interface with HTTP2.
type http2Client struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
@@ -60,8 +69,6 @@ type http2Client struct {
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
// awakenKeepalive is used to wake up keepalive when after it has gone dormant.
awakenKeepalive chan struct{}
framer *framer
// controlBuf delivers all the control related tasks (e.g., window
@@ -73,11 +80,8 @@ type http2Client struct {
isSecure bool
creds []credentials.PerRPCCredentials
perRPCCreds []credentials.PerRPCCredentials
// Boolean to keep track of reading activity on transport.
// 1 is true and 0 is false.
activity uint32 // Accessed atomically.
kp keepalive.ClientParameters
keepaliveEnabled bool
@@ -89,10 +93,10 @@ type http2Client struct {
maxSendHeaderListSize *uint32
bdpEst *bdpEstimator
// onSuccess is a callback that client transport calls upon
// onPrefaceReceipt is a callback that client transport calls upon
// receiving server preface to signal that a successful HTTP2
// connection was established.
onSuccess func()
onPrefaceReceipt func()
maxConcurrentStreams uint32
streamQuota int64
@@ -108,17 +112,34 @@ type http2Client struct {
// goAwayReason records the http2.ErrCode and debug data received with the
// GoAway frame.
goAwayReason GoAwayReason
// A condition variable used to signal when the keepalive goroutine should
// go dormant. The condition for dormancy is based on the number of active
// streams and the `PermitWithoutStream` keepalive client parameter. And
// since the number of active streams is guarded by the above mutex, we use
// the same for this condition variable as well.
kpDormancyCond *sync.Cond
// A boolean to track whether the keepalive goroutine is dormant or not.
// This is checked before attempting to signal the above condition
// variable.
kpDormant bool
// Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number
czData *channelzData
onGoAway func(GoAwayReason)
onClose func()
bufferPool *bufferPool
connectionID uint64
}
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
if fn != nil {
return fn(ctx, addr)
}
return dialContext(ctx, "tcp", addr)
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
}
func isTemporary(err error) bool {
@@ -140,7 +161,7 @@ func isTemporary(err error) bool {
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onSuccess func()) (_ *http2Client, err error) {
func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onPrefaceReceipt func(), onGoAway func(GoAwayReason), onClose func()) (_ *http2Client, err error) {
scheme := "http"
ctx, cancel := context.WithCancel(ctx)
defer func() {
@@ -162,18 +183,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
conn.Close()
}
}(conn)
var (
isSecure bool
authInfo credentials.AuthInfo
)
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn)
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
}
kp := opts.KeepaliveParams
// Validate keepalive parameters.
if kp.Time == 0 {
@@ -182,6 +191,36 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
if kp.Timeout == 0 {
kp.Timeout = defaultClientKeepaliveTimeout
}
keepaliveEnabled := false
if kp.Time != infinity {
if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil {
return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err)
}
keepaliveEnabled = true
}
var (
isSecure bool
authInfo credentials.AuthInfo
)
transportCreds := opts.TransportCredentials
perRPCCreds := opts.PerRPCCredentials
if b := opts.CredsBundle; b != nil {
if t := b.TransportCredentials(); t != nil {
transportCreds = t
}
if t := b.PerRPCCredentials(); t != nil {
perRPCCreds = append(perRPCCreds, t)
}
}
if transportCreds != nil {
scheme = "https"
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.Authority, conn)
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
}
dynamicWindow := true
icwz := int32(initialWindowSize)
if opts.InitialConnWindowSize >= defaultWindowSize {
@@ -207,22 +246,25 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1),
framer: newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme,
activeStreams: make(map[uint32]*Stream),
isSecure: isSecure,
creds: opts.PerRPCCredentials,
perRPCCreds: perRPCCreds,
kp: kp,
statsHandler: opts.StatsHandler,
initialWindowSize: initialWindowSize,
onSuccess: onSuccess,
onPrefaceReceipt: onPrefaceReceipt,
nextID: 1,
maxConcurrentStreams: defaultMaxStreamsClient,
streamQuota: defaultMaxStreamsClient,
streamsQuotaAvailable: make(chan struct{}, 1),
czData: new(channelzData),
onGoAway: onGoAway,
onClose: onClose,
keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
}
t.controlBuf = newControlBuffer(t.ctxDone)
if opts.InitialWindowSize >= defaultWindowSize {
@@ -235,9 +277,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
updateFlowControl: t.updateFlowControl,
}
}
// Make sure awakenKeepalive can't be written upon.
// keepalive routine will make it writable, if need be.
t.awakenKeepalive <- struct{}{}
if t.statsHandler != nil {
t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
@@ -249,16 +288,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
t.statsHandler.HandleConn(t.ctx, connBegin)
}
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, "")
t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr))
}
if t.kp.Time != infinity {
t.keepaliveEnabled = true
if t.keepaliveEnabled {
t.kpDormancyCond = sync.NewCond(&t.mu)
go t.keepalive()
}
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
// dispatches the frame to the corresponding stream entity.
go t.reader()
// Send connection preface to server.
n, err := t.conn.Write(clientPreface)
if err != nil {
@@ -295,7 +335,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err)
}
}
t.framer.writer.Flush()
t.connectionID = atomic.AddUint64(&clientConnectionCounter, 1)
if err := t.framer.writer.Flush(); err != nil {
return nil, err
}
go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst)
err := t.loopy.run()
@@ -315,6 +360,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
ct: t,
done: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
@@ -335,6 +381,10 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: s.buf,
closeStream: func(err error) {
t.CloseStream(s, err)
},
freeBuffer: t.bufferPool.put,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
@@ -344,23 +394,24 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
}
func (t *http2Client) getPeer() *peer.Peer {
pr := &peer.Peer{
Addr: t.remoteAddr,
return &peer.Peer{
Addr: t.remoteAddr,
AuthInfo: t.authInfo,
}
// Attach Auth info if there is any.
if t.authInfo != nil {
pr.AuthInfo = t.authInfo
}
return pr
}
func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
authData, err := t.getTrAuthData(ctx, aud)
ri := credentials.RequestInfo{
Method: callHdr.Method,
AuthInfo: t.authInfo,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
authData, err := t.getTrAuthData(ctxWithRequestInfo, aud)
if err != nil {
return nil, err
}
callAuthData, err := t.getCallAuthData(ctx, aud, callHdr)
callAuthData, err := t.getCallAuthData(ctxWithRequestInfo, aud, callHdr)
if err != nil {
return nil, err
}
@@ -383,11 +434,12 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
if callHdr.SendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: callHdr.SendCompress})
}
if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless of its value. The server can detect timeout context by itself.
// TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
timeout := dl.Sub(time.Now())
timeout := time.Until(dl)
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
@@ -405,6 +457,15 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
var k string
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
for _, vv := range added {
for i, v := range vv {
if i%2 == 0 {
@@ -418,15 +479,6 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)})
}
}
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
}
if md, ok := t.md.(*metadata.MD); ok {
for k, vv := range *md {
@@ -443,7 +495,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
func (t *http2Client) createAudience(callHdr *CallHdr) string {
// Create an audience string only if needed.
if len(t.creds) == 0 && callHdr.Creds == nil {
if len(t.perRPCCreds) == 0 && callHdr.Creds == nil {
return ""
}
// Construct URI required to get auth request metadata.
@@ -457,8 +509,11 @@ func (t *http2Client) createAudience(callHdr *CallHdr) string {
}
func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) {
if len(t.perRPCCreds) == 0 {
return nil, nil
}
authData := map[string]string{}
for _, c := range t.creds {
for _, c := range t.perRPCCreds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
if _, ok := status.FromError(err); ok {
@@ -477,7 +532,7 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s
}
func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) {
callAuthData := map[string]string{}
var callAuthData map[string]string
// Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
@@ -489,6 +544,7 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
if err != nil {
return nil, status.Errorf(codes.Internal, "transport: %v", err)
}
callAuthData = make(map[string]string, len(data))
for k, v := range data {
// Capital header names are illegal in HTTP/2
k = strings.ToLower(k)
@@ -517,15 +573,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
s.write(recvMsg{err: err})
close(s.done)
// If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
close(s.headerChan)
}
}
hdr := &headerFrame{
hf: headerFields,
endStream: false,
initStream: func(id uint32) (bool, error) {
initStream: func(id uint32) error {
t.mu.Lock()
if state := t.state; state != reachable {
t.mu.Unlock()
@@ -535,29 +590,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
err = ErrConnClosing
}
cleanup(err)
return false, err
return err
}
t.activeStreams[id] = s
if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1)
atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano())
}
var sendPing bool
// If the number of active streams change from 0 to 1, then check if keepalive
// has gone dormant. If so, wake it up.
if len(t.activeStreams) == 1 && t.keepaliveEnabled {
select {
case t.awakenKeepalive <- struct{}{}:
sendPing = true
// Fill the awakenKeepalive channel again as this channel must be
// kept non-writable except at the point that the keepalive()
// goroutine is waiting either to be awaken or shutdown.
t.awakenKeepalive <- struct{}{}
default:
}
// If the keepalive goroutine has gone dormant, wake it up.
if t.kpDormant {
t.kpDormancyCond.Signal()
}
t.mu.Unlock()
return sendPing, nil
return nil
},
onOrphaned: cleanup,
wq: s.wq,
@@ -635,12 +680,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
}
if t.statsHandler != nil {
header, _, _ := metadata.FromOutgoingContextRaw(ctx)
outHeader := &stats.OutHeader{
Client: true,
FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: callHdr.SendCompress,
Header: header.Copy(),
}
t.statsHandler.HandleRPC(s.ctx, outHeader)
}
@@ -664,7 +711,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
// Set stream status to done.
if s.swapState(streamDone) == streamDone {
// If it was already done, return.
// If it was already done, return. If multiple closeStream calls
// happen simultaneously, wait for the first to finish.
<-s.done
return
}
// status and trailers can be updated here without any synchronization because the stream goroutine will
@@ -678,10 +727,8 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// This will unblock reads eventually.
s.write(recvMsg{err: err})
}
// This will unblock write.
close(s.done)
// If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.noHeaders = true
close(s.headerChan)
}
@@ -715,11 +762,17 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
return true
}
t.controlBuf.executeAndPut(addBackStreamQuota, cleanup)
// This will unblock write.
close(s.done)
}
// Close kicks off the shutdown process of the transport. This should be called
// only once on a transport. Once it is called, the transport should not be
// accessed any more.
//
// This method blocks until the addrConn that initiated this transport is
// re-connected. This happens because t.onClose() begins reconnect logic at the
// addrConn level and blocks until the addrConn is successfully connected.
func (t *http2Client) Close() error {
t.mu.Lock()
// Make sure we only Close once.
@@ -727,9 +780,17 @@ func (t *http2Client) Close() error {
t.mu.Unlock()
return nil
}
// Call t.onClose before setting the state to closing to prevent the client
// from attempting to create new streams ASAP.
t.onClose()
t.state = closing
streams := t.activeStreams
t.activeStreams = nil
if t.kpDormant {
// If the keepalive goroutine is blocked on this condition variable, we
// should unblock it so that the goroutine eventually exits.
t.kpDormancyCond.Signal()
}
t.mu.Unlock()
t.controlBuf.finish()
t.cancel()
@@ -755,21 +816,21 @@ func (t *http2Client) Close() error {
// stream is closed. If there are no active streams, the transport is closed
// immediately. This does nothing if the transport is already draining or
// closing.
func (t *http2Client) GracefulClose() error {
func (t *http2Client) GracefulClose() {
t.mu.Lock()
// Make sure we move to draining only from active.
if t.state == draining || t.state == closing {
t.mu.Unlock()
return nil
return
}
t.state = draining
active := len(t.activeStreams)
t.mu.Unlock()
if active == 0 {
return t.Close()
t.Close()
return
}
t.controlBuf.put(&incomingGoAway{})
return nil
}
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
@@ -805,11 +866,11 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
return t.controlBuf.put(df)
}
func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
func (t *http2Client) getStream(f http2.Frame) *Stream {
t.mu.Lock()
defer t.mu.Unlock()
s, ok := t.activeStreams[f.Header().StreamID]
return s, ok
s := t.activeStreams[f.Header().StreamID]
t.mu.Unlock()
return s
}
// adjustWindow sends out extra window update over the initial window size
@@ -889,8 +950,8 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
t.controlBuf.put(bdpPing)
}
// Select the right stream to dispatch.
s, ok := t.getStream(f)
if !ok {
s := t.getStream(f)
if s == nil {
return
}
if size > 0 {
@@ -907,9 +968,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
data := make([]byte, len(f.Data()))
copy(data, f.Data())
s.write(recvMsg{data: data})
buffer := t.bufferPool.get()
buffer.Reset()
buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
}
}
// The server has closed the stream without sending trailers. Record that
@@ -920,8 +982,8 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
}
func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
s, ok := t.getStream(f)
if !ok {
s := t.getStream(f)
if s == nil {
return
}
if f.ErrCode == http2.ErrCodeRefusedStream {
@@ -934,9 +996,9 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
statusCode = codes.Unknown
}
if statusCode == codes.Canceled {
// Our deadline was already exceeded, and that was likely the cause of
// this cancelation. Alter the status code accordingly.
if d, ok := s.ctx.Deadline(); ok && d.After(time.Now()) {
if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
// Our deadline was already exceeded, and that was likely the cause
// of this cancelation. Alter the status code accordingly.
statusCode = codes.DeadlineExceeded
}
}
@@ -1041,8 +1103,12 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
default:
t.setGoAwayReason(f)
close(t.goAway)
t.state = draining
t.controlBuf.put(&incomingGoAway{})
// Notify the clientconn about the GOAWAY before we set the state to
// draining, to allow the client to stop attempting to create streams
// before disallowing new streams on this connection.
t.onGoAway(t.goAwayReason)
t.state = draining
}
// All streams with IDs greater than the GoAwayId
// and smaller than the previous GoAway ID should be killed.
@@ -1094,58 +1160,77 @@ func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
// operateHeaders takes action on the decoded headers.
func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
s, ok := t.getStream(frame)
if !ok {
s := t.getStream(frame)
if s == nil {
return
}
endStream := frame.StreamEnded()
atomic.StoreUint32(&s.bytesReceived, 1)
var state decodeState
if err := state.decodeHeader(frame); err != nil {
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.New(codes.Internal, err.Error()), nil, false)
// Something wrong. Stops reading even when there is remaining.
initialHeader := atomic.LoadUint32(&s.headerChanClosed) == 0
if !initialHeader && !endStream {
// As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream")
t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false)
return
}
endStream := frame.StreamEnded()
var isHeader bool
state := &decodeState{}
// Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode.
state.data.isGRPC = !initialHeader
if err := state.decodeHeader(frame); err != nil {
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream)
return
}
isHeader := false
defer func() {
if t.statsHandler != nil {
if isHeader {
inHeader := &stats.InHeader{
Client: true,
WireLength: int(frame.Header().Length),
Header: s.header.Copy(),
}
t.statsHandler.HandleRPC(s.ctx, inHeader)
} else {
inTrailer := &stats.InTrailer{
Client: true,
WireLength: int(frame.Header().Length),
Trailer: s.trailer.Copy(),
}
t.statsHandler.HandleRPC(s.ctx, inTrailer)
}
}
}()
// If headers haven't been received yet.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
// If headerChan hasn't been closed yet
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.headerValid = true
if !endStream {
// Headers frame is not actually a trailers-only frame.
// HEADERS frame block carries a Response-Headers.
isHeader = true
// These values can be set without any synchronization because
// stream goroutine will read it only after seeing a closed
// headerChan which we'll close after setting this.
s.recvCompress = state.encoding
if len(state.mdata) > 0 {
s.header = state.mdata
s.recvCompress = state.data.encoding
if len(state.data.mdata) > 0 {
s.header = state.data.mdata
}
} else {
// HEADERS frame block carries a Trailers-Only.
s.noHeaders = true
}
close(s.headerChan)
}
if !endStream {
return
}
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, state.status(), state.mdata, true)
// if client received END_STREAM from server while stream was still active, send RST_STREAM
rst := s.getState() == streamActive
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true)
}
// reader runs as a separate goroutine in charge of reading data from network
@@ -1159,25 +1244,27 @@ func (t *http2Client) reader() {
// Check the validity of server preface.
frame, err := t.framer.fr.ReadFrame()
if err != nil {
t.Close()
t.Close() // this kicks off resetTransport, so must be last before return
return
}
t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!)
if t.keepaliveEnabled {
atomic.CompareAndSwapUint32(&t.activity, 0, 1)
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
}
sf, ok := frame.(*http2.SettingsFrame)
if !ok {
t.Close()
t.Close() // this kicks off resetTransport, so must be last before return
return
}
t.onSuccess()
t.onPrefaceReceipt()
t.handleSettings(sf, true)
// loop to keep reading incoming messages on this transport.
for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame()
if t.keepaliveEnabled {
atomic.CompareAndSwapUint32(&t.activity, 0, 1)
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
}
if err != nil {
// Abort an active stream if the http2.Framer returns a
@@ -1221,55 +1308,83 @@ func (t *http2Client) reader() {
}
}
// minTime returns the smaller of the two durations a and b. When they are
// equal, that common value is returned.
func minTime(a, b time.Duration) time.Duration {
	if b <= a {
		return b
	}
	return a
}
// keepalive running in a separate goroutine makes sure the connection is alive by sending pings.
func (t *http2Client) keepalive() {
p := &ping{data: [8]byte{}}
// True iff a ping has been sent, and no data has been received since then.
outstandingPing := false
// Amount of time remaining before which we should receive an ACK for the
// last sent ping.
timeoutLeft := time.Duration(0)
// Records the last value of t.lastRead before we block on the timer.
// This is required to check for read activity since then.
prevNano := time.Now().UnixNano()
timer := time.NewTimer(t.kp.Time)
for {
select {
case <-timer.C:
if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
timer.Reset(t.kp.Time)
lastRead := atomic.LoadInt64(&t.lastRead)
if lastRead > prevNano {
// There has been read activity since the last time we were here.
outstandingPing = false
// Next timer should fire at kp.Time seconds from lastRead time.
timer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano()))
prevNano = lastRead
continue
}
// Check if keepalive should go dormant.
if outstandingPing && timeoutLeft <= 0 {
t.Close()
return
}
t.mu.Lock()
if t.state == closing {
// If the transport is closing, we should exit from the
// keepalive goroutine here. If not, we could have a race
// between the call to Signal() from Close() and the call to
// Wait() here, whereby the keepalive goroutine ends up
// blocking on the condition variable which will never be
// signalled again.
t.mu.Unlock()
return
}
if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream {
// Make awakenKeepalive writable.
<-t.awakenKeepalive
t.mu.Unlock()
select {
case <-t.awakenKeepalive:
// If the control gets here a ping has been sent
// need to reset the timer with keepalive.Timeout.
case <-t.ctx.Done():
return
}
} else {
t.mu.Unlock()
// If a ping was sent out previously (because there were active
// streams at that point) which wasn't acked and its timeout
// hadn't fired, but we got here and are about to go dormant,
// we should make sure that we unconditionally send a ping once
// we awaken.
outstandingPing = false
t.kpDormant = true
t.kpDormancyCond.Wait()
}
t.kpDormant = false
t.mu.Unlock()
// We get here either because we were dormant and a new stream was
// created which unblocked the Wait() call, or because the
// keepalive timer expired. In both cases, we need to send a ping.
if !outstandingPing {
if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1)
}
// Send ping.
t.controlBuf.put(p)
timeoutLeft = t.kp.Timeout
outstandingPing = true
}
// By the time control gets here a ping has been sent one way or the other.
timer.Reset(t.kp.Timeout)
select {
case <-timer.C:
if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
timer.Reset(t.kp.Time)
continue
}
t.Close()
return
case <-t.ctx.Done():
if !timer.Stop() {
<-timer.C
}
return
}
// The amount of time to sleep here is the minimum of kp.Time and
// timeoutLeft. This will ensure that we wait only for kp.Time
// before sending out the next ping (for cases where the ping is
// acked).
sleepDuration := minTime(t.kp.Time, timeoutLeft)
timeoutLeft -= sleepDuration
timer.Reset(sleepDuration)
case <-t.ctx.Done():
if !timer.Stop() {
<-timer.C
@@ -1311,6 +1426,8 @@ func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric {
return &s
}
// RemoteAddr returns the value of the transport's remoteAddr field.
func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr }
func (t *http2Client) IncrMsgSent() {
atomic.AddInt64(&t.czData.msgSent, 1)
atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano())

View File

@@ -20,6 +20,7 @@ package transport
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -31,13 +32,14 @@ import (
"time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/keepalive"
@@ -55,13 +57,20 @@ var (
// ErrHeaderListSizeLimitViolation indicates that the header list size is larger
// than the limit set by peer.
ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer")
// statusRawProto is a function to get to the raw status proto wrapped in a
// status.Status without a proto.Clone().
statusRawProto = internal.StatusRawProto.(func(*status.Status) *spb.Status)
)
// serverConnectionCounter counts the number of connections a server has seen
// (equal to the number of http2Servers created). Must be accessed atomically.
var serverConnectionCounter uint64
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
ctxDone <-chan struct{} // Cache the context.Done() chan
cancel context.CancelFunc
done chan struct{}
conn net.Conn
loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing.
@@ -79,12 +88,8 @@ type http2Server struct {
controlBuf *controlBuffer
fc *trInFlow
stats stats.Handler
// Flag to keep track of reading activity on transport.
// 1 is true and 0 is false.
activity uint32 // Accessed atomically.
// Keepalive and max-age parameters for the server.
kp keepalive.ServerParameters
// Keepalive enforcement policy.
kep keepalive.EnforcementPolicy
// The time instance last ping was received.
@@ -119,6 +124,9 @@ type http2Server struct {
// Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number
czData *channelzData
bufferPool *bufferPool
connectionID uint64
}
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
@@ -132,7 +140,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
}
framer := newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize)
// Send initial settings as connection preface to client.
var isettings []http2.Setting
isettings := []http2.Setting{{
ID: http2.SettingMaxFrameSize,
Val: http2MaxFrameLen,
}}
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
// permitted in the HTTP2 spec.
maxStreams := config.MaxStreams
@@ -166,6 +177,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
Val: *config.MaxHeaderListSize,
})
}
if config.HeaderTableSize != nil {
isettings = append(isettings, http2.Setting{
ID: http2.SettingHeaderTableSize,
Val: *config.HeaderTableSize,
})
}
if err := framer.fr.WriteSettings(isettings...); err != nil {
return nil, connectionErrorf(false, err, "transport: %v", err)
}
@@ -197,11 +214,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
t := &http2Server{
ctx: ctx,
cancel: cancel,
ctxDone: ctx.Done(),
ctx: context.Background(),
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
@@ -220,8 +236,9 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
kep: kep,
initialWindowSize: iwz,
czData: new(channelzData),
bufferPool: newBufferPool(),
}
t.controlBuf = newControlBuffer(t.ctxDone)
t.controlBuf = newControlBuffer(t.done)
if dynamicWindow {
t.bdpEst = &bdpEstimator{
bdp: initialWindowSize,
@@ -237,8 +254,11 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
t.stats.HandleConn(t.ctx, connBegin)
}
if channelz.IsOn() {
t.channelzID = channelz.RegisterNormalSocket(t, config.ChannelzParentID, "")
t.channelzID = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
}
t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1)
t.framer.writer.Flush()
defer func() {
@@ -263,7 +283,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
if err != nil {
return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
}
atomic.StoreUint32(&t.activity, 1)
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
sf, ok := frame.(*http2.SettingsFrame)
if !ok {
return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame)
@@ -286,7 +306,9 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
// operateHeaders takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) {
streamID := frame.Header().StreamID
state := decodeState{serverSide: true}
state := &decodeState{
serverSide: true,
}
if err := state.decodeHeader(frame); err != nil {
if se, ok := status.FromError(err); ok {
t.controlBuf.put(&cleanupStream{
@@ -305,16 +327,16 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
recvCompress: state.encoding,
method: state.method,
contentSubtype: state.contentSubtype,
recvCompress: state.data.encoding,
method: state.data.method,
contentSubtype: state.data.contentSubtype,
}
if frame.StreamEnded() {
// s is just created by the caller. No lock needed.
s.state = streamReadDone
}
if state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
if state.data.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout)
} else {
s.ctx, s.cancel = context.WithCancel(t.ctx)
}
@@ -327,19 +349,19 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
}
s.ctx = peer.NewContext(s.ctx, pr)
// Attach the received metadata to the context.
if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
if len(state.data.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata)
}
if state.statsTags != nil {
s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags)
if state.data.statsTags != nil {
s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags)
}
if state.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace)
if state.data.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
}
if t.inTapHandle != nil {
var err error
info := &tap.Info{
FullMethodName: state.method,
FullMethodName: state.data.method,
}
s.ctx, err = t.inTapHandle(s.ctx, info)
if err != nil {
@@ -350,12 +372,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
s.cancel()
return false
}
}
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
s.cancel()
return false
}
if uint32(len(t.activeStreams)) >= t.maxStreams {
@@ -366,12 +390,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
s.cancel()
return false
}
if streamID%2 != 1 || streamID <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id.
errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", streamID)
s.cancel()
return true
}
t.maxStreamID = streamID
@@ -396,6 +422,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
LocalAddr: t.localAddr,
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
Header: metadata.MD(state.data.mdata).Copy(),
}
t.stats.HandleRPC(s.ctx, inHeader)
}
@@ -403,9 +430,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
freeBuffer: t.bufferPool.put,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
@@ -426,8 +454,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
defer close(t.readerDone)
for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame()
atomic.StoreUint32(&t.activity, 1)
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
if err != nil {
if se, ok := err.(http2.StreamError); ok {
warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se)
@@ -435,7 +464,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
s := t.activeStreams[se.StreamID]
t.mu.Unlock()
if s != nil {
t.closeStream(s, true, se.Code, nil, false)
t.closeStream(s, true, se.Code, false)
} else {
t.controlBuf.put(&cleanupStream{
streamID: se.StreamID,
@@ -577,7 +606,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
}
if size > 0 {
if err := s.fc.onData(size); err != nil {
t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false)
t.closeStream(s, true, http2.ErrCodeFlowControl, false)
return
}
if f.Header().Flags.Has(http2.FlagDataPadded) {
@@ -589,9 +618,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
data := make([]byte, len(f.Data()))
copy(data, f.Data())
s.write(recvMsg{data: data})
buffer := t.bufferPool.get()
buffer.Reset()
buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
}
}
if f.Header().Flags.Has(http2.FlagDataEndStream) {
@@ -602,11 +632,18 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
}
func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) {
s, ok := t.getStream(f)
if !ok {
// If the stream is not deleted from the transport's active streams map, then do a regular close stream.
if s, ok := t.getStream(f); ok {
t.closeStream(s, false, 0, false)
return
}
t.closeStream(s, false, 0, nil, false)
// If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map.
t.controlBuf.put(&cleanupStream{
streamID: f.Header().StreamID,
rst: false,
rstCode: 0,
onWrite: func() {},
})
}
func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
@@ -727,7 +764,7 @@ func (t *http2Server) checkForHeaderListSize(it interface{}) bool {
return true
}
// WriteHeader sends the header metedata md back to the client.
// WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
if s.updateHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
@@ -748,6 +785,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
return nil
}
// setResetPingStrikes atomically stores 1 in t.resetPingStrikes. It is defined
// as a method so it can be passed directly as the onWrite/onEachWrite callback
// of the header and data frames this transport queues, instead of allocating
// an equivalent closure at each call site.
func (t *http2Server) setResetPingStrikes() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
func (t *http2Server) writeHeaderLocked(s *Stream) error {
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
// first and create a slice of that exact size.
@@ -762,21 +803,21 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
streamID: s.id,
hf: headerFields,
endStream: false,
onWrite: func() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
onWrite: t.setResetPingStrikes,
})
if !success {
if err != nil {
return err
}
t.closeStream(s, true, http2.ErrCodeInternal, nil, false)
t.closeStream(s, true, http2.ErrCodeInternal, false)
return ErrHeaderListSizeLimitViolation
}
if t.stats != nil {
// Note: WireLength is not set in outHeader.
// TODO(mmukhi): Revisit this later, if needed.
outHeader := &stats.OutHeader{}
outHeader := &stats.OutHeader{
Header: s.header.Copy(),
}
t.stats.HandleRPC(s.Context(), outHeader)
}
return nil
@@ -808,7 +849,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
if p := st.Proto(); p != nil && len(p.Details) > 0 {
if p := statusRawProto(st); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
// TODO: return error instead, when callers are able to handle it.
@@ -824,9 +865,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
streamID: s.id,
hf: headerFields,
endStream: true,
onWrite: func() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
onWrite: t.setResetPingStrikes,
}
s.hdrMu.Unlock()
success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader)
@@ -834,12 +873,16 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
if err != nil {
return err
}
t.closeStream(s, true, http2.ErrCodeInternal, nil, false)
t.closeStream(s, true, http2.ErrCodeInternal, false)
return ErrHeaderListSizeLimitViolation
}
t.closeStream(s, false, 0, trailingHeader, true)
// Send a RST_STREAM after the trailers if the client has not already half-closed.
rst := s.getState() == streamActive
t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true)
if t.stats != nil {
t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
t.stats.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
}
return nil
}
@@ -849,6 +892,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil {
if _, ok := err.(ConnectionError); ok {
return err
}
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
return status.Errorf(codes.Internal, "transport: %v", err)
}
@@ -858,7 +904,7 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
// TODO(mmukhi, dfawley): Should the server write also return io.EOF?
s.cancel()
select {
case <-t.ctx.Done():
case <-t.done:
return ErrConnClosing
default:
}
@@ -873,16 +919,14 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
df := &dataFrame{
streamID: s.id,
h: hdr,
d: data,
onEachWrite: func() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
streamID: s.id,
h: hdr,
d: data,
onEachWrite: t.setResetPingStrikes,
}
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
select {
case <-t.ctx.Done():
case <-t.done:
return ErrConnClosing
default:
}
@@ -899,32 +943,35 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
// after an additional duration of keepalive.Timeout.
func (t *http2Server) keepalive() {
p := &ping{}
var pingSent bool
maxIdle := time.NewTimer(t.kp.MaxConnectionIdle)
maxAge := time.NewTimer(t.kp.MaxConnectionAge)
keepalive := time.NewTimer(t.kp.Time)
// NOTE: All exit paths of this function should reset their
// respective timers. A failure to do so will cause the
// following clean-up to deadlock and eventually leak.
// True iff a ping has been sent, and no data has been received since then.
outstandingPing := false
// Amount of time remaining before which we should receive an ACK for the
// last sent ping.
kpTimeoutLeft := time.Duration(0)
// Records the last value of t.lastRead before we block on the timer.
// This is required to check for read activity since then.
prevNano := time.Now().UnixNano()
// Initialize the different timers to their default values.
idleTimer := time.NewTimer(t.kp.MaxConnectionIdle)
ageTimer := time.NewTimer(t.kp.MaxConnectionAge)
kpTimer := time.NewTimer(t.kp.Time)
defer func() {
if !maxIdle.Stop() {
<-maxIdle.C
}
if !maxAge.Stop() {
<-maxAge.C
}
if !keepalive.Stop() {
<-keepalive.C
}
// We need to drain the underlying channel in these timers after a call
// to Stop(), only if we are interested in resetting them. Clearly we
// are not interested in resetting them here.
idleTimer.Stop()
ageTimer.Stop()
kpTimer.Stop()
}()
for {
select {
case <-maxIdle.C:
case <-idleTimer.C:
t.mu.Lock()
idle := t.idle
if idle.IsZero() { // The connection is non-idle.
t.mu.Unlock()
maxIdle.Reset(t.kp.MaxConnectionIdle)
idleTimer.Reset(t.kp.MaxConnectionIdle)
continue
}
val := t.kp.MaxConnectionIdle - time.Since(idle)
@@ -933,42 +980,52 @@ func (t *http2Server) keepalive() {
// The connection has been idle for a duration of keepalive.MaxConnectionIdle or more.
// Gracefully close the connection.
t.drain(http2.ErrCodeNo, []byte{})
// Resetting the timer so that the clean-up doesn't deadlock.
maxIdle.Reset(infinity)
return
}
maxIdle.Reset(val)
case <-maxAge.C:
idleTimer.Reset(val)
case <-ageTimer.C:
t.drain(http2.ErrCodeNo, []byte{})
maxAge.Reset(t.kp.MaxConnectionAgeGrace)
ageTimer.Reset(t.kp.MaxConnectionAgeGrace)
select {
case <-maxAge.C:
case <-ageTimer.C:
// Close the connection after grace period.
infof("transport: closing server transport due to maximum connection age.")
t.Close()
// Resetting the timer so that the clean-up doesn't deadlock.
maxAge.Reset(infinity)
case <-t.ctx.Done():
case <-t.done:
}
return
case <-keepalive.C:
if atomic.CompareAndSwapUint32(&t.activity, 1, 0) {
pingSent = false
keepalive.Reset(t.kp.Time)
case <-kpTimer.C:
lastRead := atomic.LoadInt64(&t.lastRead)
if lastRead > prevNano {
// There has been read activity since the last time we were
// here. Setup the timer to fire at kp.Time seconds from
// lastRead time and continue.
outstandingPing = false
kpTimer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano()))
prevNano = lastRead
continue
}
if pingSent {
if outstandingPing && kpTimeoutLeft <= 0 {
infof("transport: closing server transport due to idleness.")
t.Close()
// Resetting the timer so that the clean-up doesn't deadlock.
keepalive.Reset(infinity)
return
}
pingSent = true
if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1)
if !outstandingPing {
if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1)
}
t.controlBuf.put(p)
kpTimeoutLeft = t.kp.Timeout
outstandingPing = true
}
t.controlBuf.put(p)
keepalive.Reset(t.kp.Timeout)
case <-t.ctx.Done():
// The amount of time to sleep here is the minimum of kp.Time and
// kpTimeoutLeft. This will ensure that we wait only for kp.Time
// before sending out the next ping (for cases where the ping is
// acked).
sleepDuration := minTime(t.kp.Time, kpTimeoutLeft)
kpTimeoutLeft -= sleepDuration
kpTimer.Reset(sleepDuration)
case <-t.done:
return
}
}
@@ -988,7 +1045,7 @@ func (t *http2Server) Close() error {
t.activeStreams = nil
t.mu.Unlock()
t.controlBuf.finish()
t.cancel()
close(t.done)
err := t.conn.Close()
if channelz.IsOn() {
channelz.RemoveEntry(t.channelzID)
@@ -1004,45 +1061,61 @@ func (t *http2Server) Close() error {
return err
}
// closeStream clears the footprint of a stream when the stream is not needed
// any more.
func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
if s.swapState(streamDone) == streamDone {
// If the stream was already done, return.
return
}
// deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
s.cancel()
cleanup := &cleanupStream{
t.mu.Lock()
if _, ok := t.activeStreams[s.id]; ok {
delete(t.activeStreams, s.id)
if len(t.activeStreams) == 0 {
t.idle = time.Now()
}
}
t.mu.Unlock()
if channelz.IsOn() {
if eosReceived {
atomic.AddInt64(&t.czData.streamsSucceeded, 1)
} else {
atomic.AddInt64(&t.czData.streamsFailed, 1)
}
}
}
// finishStream closes the stream and puts the trailing headerFrame into controlbuf.
func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
oldState := s.swapState(streamDone)
if oldState == streamDone {
// If the stream was already done, return.
return
}
hdr.cleanup = &cleanupStream{
streamID: s.id,
rst: rst,
rstCode: rstCode,
onWrite: func() {
t.mu.Lock()
if t.activeStreams != nil {
delete(t.activeStreams, s.id)
if len(t.activeStreams) == 0 {
t.idle = time.Now()
}
}
t.mu.Unlock()
if channelz.IsOn() {
if eosReceived {
atomic.AddInt64(&t.czData.streamsSucceeded, 1)
} else {
atomic.AddInt64(&t.czData.streamsFailed, 1)
}
}
t.deleteStream(s, eosReceived)
},
}
if hdr != nil {
hdr.cleanup = cleanup
t.controlBuf.put(hdr)
} else {
t.controlBuf.put(cleanup)
}
t.controlBuf.put(hdr)
}
// closeStream removes every trace of a stream that is no longer needed. It
// marks s as done, drops it from the transport's bookkeeping via deleteStream,
// and then queues a cleanupStream item carrying the caller's rst/rstCode on
// the control buffer. No header frame is queued on this path.
func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
	s.swapState(streamDone)
	t.deleteStream(s, eosReceived)

	// deleteStream has already run above, so the onWrite hook of the queued
	// item is an empty function.
	noop := func() {}
	item := &cleanupStream{
		streamID: s.id,
		rst:      rst,
		rstCode:  rstCode,
		onWrite:  noop,
	}
	t.controlBuf.put(item)
}
func (t *http2Server) RemoteAddr() net.Addr {
@@ -1112,7 +1185,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
select {
case <-t.drainChan:
case <-timer.C:
case <-t.ctx.Done():
case <-t.done:
return
}
t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData})
@@ -1155,14 +1228,14 @@ func (t *http2Server) IncrMsgRecv() {
}
func (t *http2Server) getOutFlowWindow() int64 {
resp := make(chan uint32)
resp := make(chan uint32, 1)
timer := time.NewTimer(time.Second)
defer timer.Stop()
t.controlBuf.put(&outFlowControlSizeRequest{resp})
select {
case sz := <-resp:
return int64(sz)
case <-t.ctxDone:
case <-t.done:
return -1
case <-timer.C:
return -2

View File

@@ -24,6 +24,7 @@ import (
"encoding/base64"
"fmt"
"io"
"math"
"net"
"net/http"
"strconv"
@@ -77,7 +78,8 @@ var (
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
}
httpStatusConvTab = map[int]codes.Code{
// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
HTTPStatusConvTab = map[int]codes.Code{
// 400 Bad Request - INTERNAL.
http.StatusBadRequest: codes.Internal,
// 401 Unauthorized - UNAUTHENTICATED.
@@ -97,9 +99,7 @@ var (
}
)
// Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished.
type decodeState struct {
type parsedHeaderData struct {
encoding string
// statusGen caches the stream status received from the trailer the server
// sent. Client side only. Do not access directly. After all trailers are
@@ -119,8 +119,30 @@ type decodeState struct {
statsTags []byte
statsTrace []byte
contentSubtype string
// isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP).
//
// We are in gRPC mode (peer speaking gRPC) if:
// * We are client side and have already received a HEADER frame that indicates gRPC peer.
// * The header contains a valid content-type, i.e. a string that starts with "application/grpc"
// And we should handle errors specific to gRPC.
//
// Otherwise (i.e. the content-type string does not start with "application/grpc", or does not
// exist), we are in HTTP fallback mode, and should handle errors specific to HTTP.
isGRPC bool
grpcErr error
httpErr error
contentTypeErr string
}
// decodeState configures the decoding criteria for a received headers frame
// and records the data decoded from it.
type decodeState struct {
// serverSide reports whether decoding is being done on the server side.
serverSide bool
// data records the state accumulated during HPACK decoding. It is filled
// with the information parsed from an HTTP HEADERS frame once the
// decodeHeader function has been invoked and has returned.
data parsedHeaderData
}
// isReservedHeader checks whether hdr belongs to HTTP2 headers
@@ -201,11 +223,11 @@ func contentType(contentSubtype string) string {
}
func (d *decodeState) status() *status.Status {
if d.statusGen == nil {
if d.data.statusGen == nil {
// No status-details were provided; generate status using code/msg.
d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg)
}
return d.statusGen
return d.data.statusGen
}
const binHdrSuffix = "-bin"
@@ -243,113 +265,146 @@ func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error {
if frame.Truncated {
return status.Error(codes.Internal, "peer header list size exceeded limit")
}
for _, hf := range frame.Fields {
if err := d.processHeaderField(hf); err != nil {
return err
d.processHeaderField(hf)
}
if d.data.isGRPC {
if d.data.grpcErr != nil {
return d.data.grpcErr
}
if d.serverSide {
return nil
}
if d.data.rawStatusCode == nil && d.data.statusGen == nil {
// gRPC status doesn't exist.
// Set rawStatusCode to be unknown and return nil error.
// So that, if the stream has ended this Unknown status
// will be propagated to the user.
// Otherwise, it will be ignored. In which case, status from
// a later trailer, that has StreamEnded flag set, is propagated.
code := int(codes.Unknown)
d.data.rawStatusCode = &code
}
}
if d.serverSide {
return nil
}
// If grpc status exists, no need to check further.
if d.rawStatusCode != nil || d.statusGen != nil {
return nil
// HTTP fallback mode
if d.data.httpErr != nil {
return d.data.httpErr
}
// If grpc status doesn't exist and http status doesn't exist,
// then it's a malformed header.
if d.httpStatus == nil {
return status.Error(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
}
var (
code = codes.Internal // when header does not include HTTP status, return INTERNAL
ok bool
)
if *(d.httpStatus) != http.StatusOK {
code, ok := httpStatusConvTab[*(d.httpStatus)]
if d.data.httpStatus != nil {
code, ok = HTTPStatusConvTab[*(d.data.httpStatus)]
if !ok {
code = codes.Unknown
}
return status.Error(code, http.StatusText(*(d.httpStatus)))
}
// gRPC status doesn't exist and http status is OK.
// Set rawStatusCode to be unknown and return nil error.
// So that, if the stream has ended this Unknown status
// will be propagated to the user.
// Otherwise, it will be ignored. In which case, status from
// a later trailer, that has StreamEnded flag set, is propagated.
code := int(codes.Unknown)
d.rawStatusCode = &code
return nil
return status.Error(code, d.constructHTTPErrMsg())
}
// constructHTTPErrMsg builds the error message returned in HTTP fallback mode.
// The message consists of two parts joined by "; ": first the HTTP status text
// and code (or a note that the status is missing), then the recorded
// content-type error (or a note that the content-type field is missing).
func (d *decodeState) constructHTTPErrMsg() string {
	// Start each part with its "missing" wording and overwrite it when the
	// parsed header data supplies something more specific.
	statusPart := "malformed header: missing HTTP status"
	if st := d.data.httpStatus; st != nil {
		statusPart = fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*st), *st)
	}

	contentTypePart := "transport: missing content-type field"
	if msg := d.data.contentTypeErr; msg != "" {
		contentTypePart = msg
	}

	return strings.Join([]string{statusPart, contentTypePart}, "; ")
}
func (d *decodeState) addMetadata(k, v string) {
if d.mdata == nil {
d.mdata = make(map[string][]string)
if d.data.mdata == nil {
d.data.mdata = make(map[string][]string)
}
d.mdata[k] = append(d.mdata[k], v)
d.data.mdata[k] = append(d.data.mdata[k], v)
}
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
switch f.Name {
case "content-type":
contentSubtype, validContentType := contentSubtype(f.Value)
if !validContentType {
return status.Errorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value)
return
}
d.contentSubtype = contentSubtype
d.data.contentSubtype = contentSubtype
// TODO: do we want to propagate the whole content-type in the metadata,
// or come up with a way to just propagate the content-subtype if it was set?
// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
// in the metadata?
d.addMetadata(f.Name, f.Value)
d.data.isGRPC = true
case "grpc-encoding":
d.encoding = f.Value
d.data.encoding = f.Value
case "grpc-status":
code, err := strconv.Atoi(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
return
}
d.rawStatusCode = &code
d.data.rawStatusCode = &code
case "grpc-message":
d.rawStatusMsg = decodeGrpcMessage(f.Value)
d.data.rawStatusMsg = decodeGrpcMessage(f.Value)
case "grpc-status-details-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
return
}
s := &spb.Status{}
if err := proto.Unmarshal(v, s); err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
return
}
d.statusGen = status.FromProto(s)
d.data.statusGen = status.FromProto(s)
case "grpc-timeout":
d.timeoutSet = true
d.data.timeoutSet = true
var err error
if d.timeout, err = decodeTimeout(f.Value); err != nil {
return status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
if d.data.timeout, err = decodeTimeout(f.Value); err != nil {
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
}
case ":path":
d.method = f.Value
d.data.method = f.Value
case ":status":
code, err := strconv.Atoi(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
return
}
d.httpStatus = &code
d.data.httpStatus = &code
case "grpc-tags-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
return
}
d.statsTags = v
d.data.statsTags = v
d.addMetadata(f.Name, string(v))
case "grpc-trace-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
return
}
d.statsTrace = v
d.data.statsTrace = v
d.addMetadata(f.Name, string(v))
default:
if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
@@ -358,11 +413,10 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
v, err := decodeMetadataHeader(f.Name, f.Value)
if err != nil {
errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
return nil
return
}
d.addMetadata(f.Name, v)
}
return nil
}
type timeoutUnit uint8
@@ -435,6 +489,10 @@ func decodeTimeout(s string) (time.Duration, error) {
if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
}
if size > 9 {
// Spec allows for 8 digits plus the unit.
return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
}
unit := timeoutUnit(s[size-1])
d, ok := timeoutUnitToDuration(unit)
if !ok {
@@ -444,6 +502,11 @@ func decodeTimeout(s string) (time.Duration, error) {
if err != nil {
return 0, err
}
const maxHours = math.MaxInt64 / int64(time.Hour)
if d == time.Hour && t > maxHours {
// This timeout would overflow math.MaxInt64; clamp it.
return time.Duration(math.MaxInt64), nil
}
return d * time.Duration(t), nil
}
@@ -604,6 +667,7 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList
writer: w,
fr: http2.NewFramer(w, r),
}
f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
// Opt-in to Frame reuse API on framer to reduce garbage.
// Frames aren't safe to read from after a subsequent call to ReadFrame.
f.fr.SetReuseFrames()

View File

@@ -22,6 +22,8 @@
package transport
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -29,7 +31,6 @@ import (
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@@ -39,10 +40,32 @@ import (
"google.golang.org/grpc/tap"
)
// bufferPool is a typed wrapper around sync.Pool whose items are
// *bytes.Buffer values.
type bufferPool struct {
	pool sync.Pool
}

// newBufferPool returns a bufferPool that allocates a fresh bytes.Buffer
// whenever the underlying pool has nothing to hand out.
func newBufferPool() *bufferPool {
	bp := new(bufferPool)
	bp.pool.New = func() interface{} {
		return &bytes.Buffer{}
	}
	return bp
}

// get takes a buffer out of the pool. The buffer is returned as-is; get does
// not reset it.
func (p *bufferPool) get() *bytes.Buffer {
	item := p.pool.Get()
	return item.(*bytes.Buffer)
}

// put hands b back to the pool for later reuse. It does not reset b.
func (p *bufferPool) put(b *bytes.Buffer) {
	p.pool.Put(b)
}
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
data []byte
buffer *bytes.Buffer
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
@@ -50,10 +73,11 @@ type recvMsg struct {
}
// recvBuffer is an unbounded channel of recvMsg structs.
// Note recvBuffer differs from controlBuffer only in that recvBuffer
// holds a channel of only recvMsg structs instead of objects implementing "item" interface.
// recvBuffer is written to much more often than
// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put"
//
// Note: recvBuffer differs from buffer.Unbounded only in the fact that it
// holds a channel of recvMsg structs instead of objects implementing "item"
// interface. recvBuffer is written to much more often and using strict recvMsg
// structs helps avoid allocation in "recvBuffer.put"
type recvBuffer struct {
c chan recvMsg
mu sync.Mutex
@@ -110,15 +134,16 @@ func (b *recvBuffer) get() <-chan recvMsg {
return b.c
}
//
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last []byte // Stores the remaining data in the previous calls.
err error
closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata.
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last *bytes.Buffer // Stores the remaining data in the previous calls.
err error
freeBuffer func(*bytes.Buffer)
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
@@ -128,31 +153,74 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, r.err = r.read(p)
if r.last != nil {
// Read remaining data left in last call.
copied, _ := r.last.Read(p)
if r.last.Len() == 0 {
r.freeBuffer(r.last)
r.last = nil
}
return copied, nil
}
if r.closeStream != nil {
n, r.err = r.readClient(p)
} else {
n, r.err = r.read(p)
}
return n, r.err
}
func (r *recvBufferReader) read(p []byte) (n int, err error) {
if r.last != nil && len(r.last) > 0 {
// Read remaining data left in last call.
copied := copy(p, r.last)
r.last = r.last[copied:]
return copied, nil
}
select {
case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
r.recv.load()
if m.err != nil {
return 0, m.err
}
copied := copy(p, m.data)
r.last = m.data[copied:]
return copied, nil
return r.readAdditional(m, p)
}
}
// readClient is the read path Read dispatches to when r.closeStream is
// non-nil. It waits for either the context to finish or the next message on
// r.recv, and passes the message it obtains to readAdditional, whose results
// it returns unchanged.
func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
// If the context is canceled, close the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error.
select {
case <-r.ctxDone:
// Note that this adds the ctx error to the end of the recv buffer, and
// reads from the head. This will delay the error until the recv buffer
// is empty, and thus will delay ctx cancellation in Recv().
//
// It's done this way to fix a race between ctx cancel and trailer. The
// race was: stream.Recv() may return the ctx error if ctxDone wins the
// race, but stream.Trailer() may return a non-nil md because the stream
// was not marked as done when the trailer was received. This
// closeStream call marks the stream as done, and thus fixes the race.
//
// TODO: delaying the ctx error seems like an unnecessary side effect.
// What we really want is to mark the stream as done, and return the ctx
// error faster.
r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, p)
case m := <-r.recv.get():
return r.readAdditional(m, p)
}
}
// readAdditional consumes one message taken from the recv buffer. It first
// calls r.recv.load(), then either returns the message's error or copies as
// much of the message's buffer into p as fits. A buffer that ends up fully
// drained is handed to r.freeBuffer and r.last is cleared; a buffer with
// bytes left over is stored in r.last.
func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) {
	r.recv.load()
	if m.err != nil {
		return 0, m.err
	}

	buf := m.buffer
	// Read's error result is ignored; only the byte count is used.
	n, _ = buf.Read(p)
	if buf.Len() > 0 {
		// Bytes remain: keep the buffer in r.last.
		r.last = buf
		return n, nil
	}

	// Fully drained: release the buffer and forget it.
	r.freeBuffer(buf)
	r.last = nil
	return n, nil
}
type streamState uint32
const (
@@ -166,6 +234,7 @@ const (
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ct *http2Client // nil for server side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
@@ -182,12 +251,20 @@ type Stream struct {
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value). Not valid on server side.
headerValid bool
// hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex
header metadata.MD // the received header metadata.
hdrMu sync.Mutex
// On client side, header keeps the received header metadata.
//
// On server side, header keeps the header set by SetHeader(). The complete
// header will be merged into this after t.WriteHeader() is called.
header metadata.MD
trailer metadata.MD // the key-value map of trailer metadata.
noHeaders bool // set if the client never received headers (set only after the stream is done).
@@ -232,26 +309,28 @@ func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() error {
func (s *Stream) waitOnHeader() {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
return
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
// Close the stream to prevent headers/trailers from changing after
// this function returns.
s.ct.CloseStream(s, ContextErr(s.ctx.Err()))
// headerChan could possibly not be closed yet if closeStream raced
// with operateHeaders; wait until it is closed explicitly here.
<-s.headerChan
case <-s.headerChan:
return nil
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
s.waitOnHeader()
return s.recvCompress
}
@@ -266,34 +345,33 @@ func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
// Header returns the header metadata of the stream.
//
// On client side, it acquires the key-value pairs of header metadata once it is
// available. It blocks until i) the metadata is ready or ii) there is no header
// metadata or iii) the stream is canceled/expired.
//
// On server side, it returns the out header after t.WriteHeader is called. It
// does not block and must not be called until after WriteHeader.
func (s *Stream) Header() (metadata.MD, error) {
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
if s.headerChan == nil {
// On server side, return the header in stream. It will be the out
// header after t.WriteHeader is called.
return s.header.Copy(), nil
default:
}
return nil, err
s.waitOnHeader()
if !s.headerValid {
return nil, s.status.Err()
}
return s.header.Copy(), nil
}
// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true, nil. If a context error happens
// first, returns it as a status error. Client-side only.
func (s *Stream) TrailersOnly() (bool, error) {
err := s.waitOnHeader()
if err != nil {
return false, err
}
// if !headerDone, some other connection error occurred.
return s.noHeaders && atomic.LoadUint32(&s.headerDone) == 1, nil
// before headers are received, returns true, nil. Client-side only.
func (s *Stream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}
// Trailer returns the cached trailer metadata. Note that if it is not called
@@ -447,6 +525,7 @@ type ServerConfig struct {
ReadBufferSize int
ChannelzParentID int64
MaxHeaderListSize *uint32
HeaderTableSize *uint32
}
// NewServerTransport creates a ServerTransport with conn or non-nil error
@@ -465,8 +544,12 @@ type ConnectOptions struct {
FailOnNonTempDialError bool
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection.
// TransportCredentials stores the Authenticator required to setup a client
// connection. Only one of TransportCredentials and CredsBundle is non-nil.
TransportCredentials credentials.TransportCredentials
// CredsBundle is the credentials bundle to be used. Only one of
// TransportCredentials and CredsBundle is non-nil.
CredsBundle credentials.Bundle
// KeepaliveParams stores the keepalive parameters.
KeepaliveParams keepalive.ClientParameters
// StatsHandler stores the handler for stats.
@@ -494,8 +577,8 @@ type TargetInfo struct {
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(connectCtx, ctx context.Context, target TargetInfo, opts ConnectOptions, onSuccess func()) (ClientTransport, error) {
return newHTTP2Client(connectCtx, ctx, target, opts, onSuccess)
func NewClientTransport(connectCtx, ctx context.Context, target TargetInfo, opts ConnectOptions, onPrefaceReceipt func(), onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) {
return newHTTP2Client(connectCtx, ctx, target, opts, onPrefaceReceipt, onGoAway, onClose)
}
// Options provides additional hints and information for message
@@ -540,9 +623,12 @@ type ClientTransport interface {
// is called only once.
Close() error
// GracefulClose starts to tear down the transport. It stops accepting
// new RPCs and wait the completion of the pending RPCs.
GracefulClose() error
// GracefulClose starts to tear down the transport: the transport will stop
// accepting new RPCs and NewStream will return error. Once all streams are
// finished, the transport will close.
//
// It does not block.
GracefulClose()
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
@@ -572,6 +658,9 @@ type ClientTransport interface {
// GetGoAwayReason returns the reason why GoAway frame was received.
GetGoAwayReason() GoAwayReason
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// IncrMsgSent increments the number of message sent through this transport.
IncrMsgSent()
@@ -706,3 +795,14 @@ type channelzData struct {
lastMsgSentTime int64
lastMsgRecvTime int64
}
// ContextErr translates an error from the context package into a gRPC status
// error: context.DeadlineExceeded maps to codes.DeadlineExceeded and
// context.Canceled maps to codes.Canceled. Any other value is reported with
// codes.Internal.
func ContextErr(err error) error {
	if err == context.DeadlineExceeded {
		return status.Error(codes.DeadlineExceeded, err.Error())
	}
	if err == context.Canceled {
		return status.Error(codes.Canceled, err.Error())
	}
	return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err)
}