diff options
Diffstat (limited to 'libgo/go/crypto/tls/conn.go')
-rw-r--r-- | libgo/go/crypto/tls/conn.go | 103 |
1 files changed, 81 insertions, 22 deletions
diff --git a/libgo/go/crypto/tls/conn.go b/libgo/go/crypto/tls/conn.go index e3dcf15400c..03775685fb6 100644 --- a/libgo/go/crypto/tls/conn.go +++ b/libgo/go/crypto/tls/conn.go @@ -16,6 +16,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) @@ -56,6 +57,11 @@ type Conn struct { input *block // application data waiting to be read hand bytes.Buffer // handshake data waiting to be read + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + tmp [16]byte } @@ -98,12 +104,13 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { type halfConn struct { sync.Mutex - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac macFunction - seq [8]byte // 64-bit sequence number - bfree *block // list of free blocks + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac macFunction + seq [8]byte // 64-bit sequence number + bfree *block // list of free blocks + additionalData [13]byte // to avoid allocs; interface method args escape nextCipher interface{} // next encryption state nextMac macFunction // next MAC algorithm @@ -262,14 +269,13 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) nonce := payload[:8] payload = payload[8:] - var additionalData [13]byte - copy(additionalData[:], hc.seq[:]) - copy(additionalData[8:], b.data[:3]) + copy(hc.additionalData[:], hc.seq[:]) + copy(hc.additionalData[8:], b.data[:3]) n := len(payload) - c.Overhead() - additionalData[11] = byte(n >> 8) - additionalData[12] = byte(n) + hc.additionalData[11] = byte(n >> 8) + hc.additionalData[12] = byte(n) var err error - payload, err = c.Open(payload[:0], nonce, payload, additionalData[:]) + payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:]) if err != nil { return false, 0, alertBadRecordMAC } @@ -378,13 +384,12 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { payload := b.data[recordHeaderLen+explicitIVLen:] payload = payload[:payloadLen] - var additionalData [13]byte - copy(additionalData[:], hc.seq[:]) - copy(additionalData[8:], b.data[:3]) - additionalData[11] = byte(payloadLen >> 8) - additionalData[12] = byte(payloadLen) + copy(hc.additionalData[:], hc.seq[:]) + copy(hc.additionalData[8:], b.data[:3]) + hc.additionalData[11] = byte(payloadLen >> 8) + hc.additionalData[12] = byte(payloadLen) - c.Seal(payload[:0], nonce, payload, additionalData[:]) + c.Seal(payload[:0], nonce, payload, hc.additionalData[:]) case cbcMode: blockSize := c.BlockSize() if explicitIVLen > 0 { @@ -507,6 +512,23 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) { return b, bb } +// RecordHeaderError results when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { + err.Msg = msg + copy(err.RecordHeader[:], c.rawInput.data) + return err +} + // readRecord reads the next TLS record from the connection // and updates the record layer state. // c.in.Mutex <= L; c.input == nil. @@ -557,18 +579,20 @@ Again: // an SSLv2 client. if want == recordTypeHandshake && typ == 0x80 { c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received")) + return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received")) } vers := uint16(b.data[1])<<8 | uint16(b.data[2]) n := int(b.data[3])<<8 | int(b.data[4]) if c.haveVers && vers != c.vers { c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(msg)) } if n > maxCiphertext { c.sendAlert(alertRecordOverflow) - return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n)) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(msg)) } if !c.haveVers { // First message, be extra suspicious: this might not be a TLS @@ -577,7 +601,7 @@ Again: // it's probably not real. if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 { c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake")) + return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake")) } } if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { @@ -837,8 +861,22 @@ func (c *Conn) readHandshake() (interface{}, error) { return m, nil } +var errClosed = errors.New("crypto/tls: use of closed connection") + // Write writes data to the connection. func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, errClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + defer atomic.AddInt32(&c.activeCall, -2) + break + } + } + if err := c.Handshake(); err != nil { return 0, err } @@ -942,6 +980,27 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Close closes the connection. func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return errClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + var alertErr error c.handshakeMutex.Lock() |