aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/crypto/tls/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/tls/conn.go')
-rw-r--r--libgo/go/crypto/tls/conn.go103
1 files changed, 81 insertions, 22 deletions
diff --git a/libgo/go/crypto/tls/conn.go b/libgo/go/crypto/tls/conn.go
index e3dcf15400c..03775685fb6 100644
--- a/libgo/go/crypto/tls/conn.go
+++ b/libgo/go/crypto/tls/conn.go
@@ -16,6 +16,7 @@ import (
"io"
"net"
"sync"
+ "sync/atomic"
"time"
)
@@ -56,6 +57,11 @@ type Conn struct {
input *block // application data waiting to be read
hand bytes.Buffer // handshake data waiting to be read
+ // activeCall is an atomic int32; the low bit is whether Close has
+ // been called. the rest of the bits are the number of goroutines
+ // in Conn.Write.
+ activeCall int32
+
tmp [16]byte
}
@@ -98,12 +104,13 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
type halfConn struct {
sync.Mutex
- err error // first permanent error
- version uint16 // protocol version
- cipher interface{} // cipher algorithm
- mac macFunction
- seq [8]byte // 64-bit sequence number
- bfree *block // list of free blocks
+ err error // first permanent error
+ version uint16 // protocol version
+ cipher interface{} // cipher algorithm
+ mac macFunction
+ seq [8]byte // 64-bit sequence number
+ bfree *block // list of free blocks
+ additionalData [13]byte // to avoid allocs; interface method args escape
nextCipher interface{} // next encryption state
nextMac macFunction // next MAC algorithm
@@ -262,14 +269,13 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert)
nonce := payload[:8]
payload = payload[8:]
- var additionalData [13]byte
- copy(additionalData[:], hc.seq[:])
- copy(additionalData[8:], b.data[:3])
+ copy(hc.additionalData[:], hc.seq[:])
+ copy(hc.additionalData[8:], b.data[:3])
n := len(payload) - c.Overhead()
- additionalData[11] = byte(n >> 8)
- additionalData[12] = byte(n)
+ hc.additionalData[11] = byte(n >> 8)
+ hc.additionalData[12] = byte(n)
var err error
- payload, err = c.Open(payload[:0], nonce, payload, additionalData[:])
+ payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
if err != nil {
return false, 0, alertBadRecordMAC
}
@@ -378,13 +384,12 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
payload := b.data[recordHeaderLen+explicitIVLen:]
payload = payload[:payloadLen]
- var additionalData [13]byte
- copy(additionalData[:], hc.seq[:])
- copy(additionalData[8:], b.data[:3])
- additionalData[11] = byte(payloadLen >> 8)
- additionalData[12] = byte(payloadLen)
+ copy(hc.additionalData[:], hc.seq[:])
+ copy(hc.additionalData[8:], b.data[:3])
+ hc.additionalData[11] = byte(payloadLen >> 8)
+ hc.additionalData[12] = byte(payloadLen)
- c.Seal(payload[:0], nonce, payload, additionalData[:])
+ c.Seal(payload[:0], nonce, payload, hc.additionalData[:])
case cbcMode:
blockSize := c.BlockSize()
if explicitIVLen > 0 {
@@ -507,6 +512,23 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
return b, bb
}
+// RecordHeaderError results when a TLS record header is invalid.
+type RecordHeaderError struct {
+ // Msg contains a human readable string that describes the error.
+ Msg string
+ // RecordHeader contains the five bytes of TLS record header that
+ // triggered the error.
+ RecordHeader [5]byte
+}
+
+func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
+
+func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
+ err.Msg = msg
+ copy(err.RecordHeader[:], c.rawInput.data)
+ return err
+}
+
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
// c.in.Mutex <= L; c.input == nil.
@@ -557,18 +579,20 @@ Again:
// an SSLv2 client.
if want == recordTypeHandshake && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
- return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
+ return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received"))
}
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if c.haveVers && vers != c.vers {
c.sendAlert(alertProtocolVersion)
- return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
+ msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
+ return c.in.setErrorLocked(c.newRecordHeaderError(msg))
}
if n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
- return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
+ msg := fmt.Sprintf("oversized record received with length %d", n)
+ return c.in.setErrorLocked(c.newRecordHeaderError(msg))
}
if !c.haveVers {
// First message, be extra suspicious: this might not be a TLS
@@ -577,7 +601,7 @@ Again:
// it's probably not real.
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 {
c.sendAlert(alertUnexpectedMessage)
- return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
+ return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
}
}
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
@@ -837,8 +861,22 @@ func (c *Conn) readHandshake() (interface{}, error) {
return m, nil
}
+var errClosed = errors.New("crypto/tls: use of closed connection")
+
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
+ // interlock with Close below
+ for {
+ x := atomic.LoadInt32(&c.activeCall)
+ if x&1 != 0 {
+ return 0, errClosed
+ }
+ if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
+ defer atomic.AddInt32(&c.activeCall, -2)
+ break
+ }
+ }
+
if err := c.Handshake(); err != nil {
return 0, err
}
@@ -942,6 +980,27 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// Close closes the connection.
func (c *Conn) Close() error {
+ // Interlock with Conn.Write above.
+ var x int32
+ for {
+ x = atomic.LoadInt32(&c.activeCall)
+ if x&1 != 0 {
+ return errClosed
+ }
+ if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
+ break
+ }
+ }
+ if x != 0 {
+ // io.Writer and io.Closer should not be used concurrently.
+ // If Close is called while a Write is currently in-flight,
+ // interpret that as a sign that this Close is really just
+ // being used to break the Write and/or clean up resources and
+ // avoid sending the alertCloseNotify, which may block
+ // waiting on handshakeMutex or the c.out mutex.
+ return c.conn.Close()
+ }
+
var alertErr error
c.handshakeMutex.Lock()