Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement STARTTLS connection upgrading #19

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package config

import (
"crypto/tls"
"encoding/json"
"flag"
"io/ioutil"
"log"
"strings"

"github.com/ian-kent/envconf"
"github.com/mailhog/MailHog-Server/monkey"
Expand Down Expand Up @@ -49,6 +51,8 @@ type Config struct {
OutgoingSMTPFile string
OutgoingSMTP map[string]*OutgoingSMTP
WebPath string
CertsPaths string
TLSConfig *tls.Config
}

// OutgoingSMTP is an outgoing SMTP server config
Expand Down Expand Up @@ -112,6 +116,32 @@ func Configure() *Config {
cfg.OutgoingSMTP = o
}

if cfg.CertsPaths != "" {
pairCandidates := strings.Split(cfg.CertsPaths, ";")

if len(pairCandidates) > 0 {
certificates := make([]tls.Certificate, len(pairCandidates))
for i, pairCandidate := range pairCandidates {
pair := strings.Split(pairCandidate, ",")

if len(pair) != 2 {
log.Fatalf("Certificate path pair %d must be in form certPath,keyPath", i)
}

cert, err := tls.LoadX509KeyPair(pair[0], pair[1])
if err != nil {
log.Fatal(err)
}

certificates[i] = cert
}

cfg.TLSConfig = &tls.Config{
Certificates: certificates,
}
}
}

return cfg
}

Expand All @@ -128,5 +158,6 @@ func RegisterFlags() {
flag.StringVar(&cfg.MaildirPath, "maildir-path", envconf.FromEnvP("MH_MAILDIR_PATH", "").(string), "Maildir path (if storage type is 'maildir')")
flag.BoolVar(&cfg.InviteJim, "invite-jim", envconf.FromEnvP("MH_INVITE_JIM", false).(bool), "Decide whether to invite Jim (beware, he causes trouble)")
flag.StringVar(&cfg.OutgoingSMTPFile, "outgoing-smtp", envconf.FromEnvP("MH_OUTGOING_SMTP", "").(string), "JSON file containing outgoing SMTP servers")
flag.StringVar(&cfg.CertsPaths, "certs-paths", envconf.FromEnvP("MH_CERTS_PATHS", "").(string), "A comma separated list of tls certificates, in schema cert1Path,key1Path;cert1Path,key2Path ... etc")
Jim.RegisterFlags()
}
28 changes: 26 additions & 2 deletions smtp/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Session struct {
}

// Accept starts a new SMTP session using io.ReadWriteCloser
func Accept(remoteAddress string, conn io.ReadWriteCloser, storage storage.Storage, messageChan chan *data.Message, hostname string, monkey monkey.ChaosMonkey) {
func Accept(remoteAddress string, conn io.ReadWriteCloser, tlsUpgrader func() io.ReadWriteCloser, storage storage.Storage, messageChan chan *data.Message, hostname string, monkey monkey.ChaosMonkey) {
defer conn.Close()

proto := smtp.NewProtocol()
Expand All @@ -56,10 +56,30 @@ func Accept(remoteAddress string, conn io.ReadWriteCloser, storage storage.Stora
proto.ValidateAuthenticationHandler = session.validateAuthentication
proto.GetAuthenticationMechanismsHandler = func() []string { return []string{"PLAIN"} }

if tlsUpgrader != nil {
proto.TLSHandler = func(done func(ok bool)) (errorReply *smtp.Reply, callback func(), ok bool) {
done(true)
return nil, func() {
newCon := tlsUpgrader()

session.reader = io.Reader(newCon)
session.writer = io.Writer(newCon)
if monkey != nil {
linkSpeed := monkey.LinkSpeed()
if linkSpeed != nil {
link = linkio.NewLink(*linkSpeed * linkio.BytePerSecond)
session.reader = link.NewLinkReader(io.Reader(newCon))
session.writer = link.NewLinkWriter(io.Writer(newCon))
}
}
}, true
}
}

session.logf("Starting session")
session.Write(proto.Start())
for session.Read() == true {
if monkey != nil && monkey.Disconnect != nil && monkey.Disconnect() {
if monkey != nil && monkey.Disconnect() {
session.conn.Close()
break
}
Expand Down Expand Up @@ -160,4 +180,8 @@ func (c *Session) Write(reply *smtp.Reply) {
c.logf("Sent %d bytes: '%s'", len(l), logText)
c.writer.Write([]byte(l))
}

if reply.Done != nil {
reply.Done()
}
}
145 changes: 111 additions & 34 deletions smtp/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package smtp

import (
"errors"
"io"
"sync"
"testing"

Expand Down Expand Up @@ -40,7 +41,7 @@ func TestAccept(t *testing.T) {
Convey("Accept should handle a connection", t, func() {
frw := &fakeRw{}
mChan := make(chan *data.Message)
Accept("1.1.1.1:11111", frw, storage.CreateInMemory(), mChan, "localhost", nil)
Accept("1.1.1.1:11111", frw, nil, storage.CreateInMemory(), mChan, "localhost", nil)
})
}

Expand All @@ -52,58 +53,134 @@ func TestSocketError(t *testing.T) {
},
}
mChan := make(chan *data.Message)
Accept("1.1.1.1:11111", frw, storage.CreateInMemory(), mChan, "localhost", nil)
Accept("1.1.1.1:11111", frw, nil, storage.CreateInMemory(), mChan, "localhost", nil)
})
}

func TestAcceptMessage(t *testing.T) {
Convey("acceptMessage should be called", t, func() {
mbuf := "EHLO localhost\nMAIL FROM:<test>\nRCPT TO:<test>\nDATA\nHi.\r\n.\r\nQUIT\n"
var rbuf []byte
frw := &fakeRw{
_read: func(p []byte) (n int, err error) {
if len(p) >= len(mbuf) {
ba := []byte(mbuf)
mbuf = ""
for i, b := range ba {
p[i] = b
}
return len(ba), nil
}
mbuf := "EHLO localhost\r\n" +
"MAIL FROM:<test>\r\n" +
"RCPT TO:<test>\r\n" +
"DATA\r\n" +
"Hi.\r\n" +
".\r\n" +
"QUIT\n"

frw, obuf := getBuffer(mbuf)
mChan := make(chan *data.Message)
var wg sync.WaitGroup
wg.Add(1)
handlerCalled := false
var storedMessage *data.Message
go func() {
handlerCalled = true
storedMessage = <-mChan
wg.Done()
}()
Accept("1.1.1.1:11111", frw, nil, storage.CreateInMemory(), mChan, "localhost", nil)
wg.Wait()

ba := []byte(mbuf[0:len(p)])
mbuf = mbuf[len(p):]
for i, b := range ba {
p[i] = b
}
return len(ba), nil
},
_write: func(p []byte) (n int, err error) {
rbuf = append(rbuf, p...)
return len(p), nil
},
_close: func() error {
return nil
},
}
So(handlerCalled, ShouldBeTrue)

So(storedMessage, ShouldNotBeNil)
So(string(*obuf), ShouldEqual,
"220 localhost ESMTP MailHog\r\n"+
"250-Hello localhost\r\n"+
"250-PIPELINING\r\n"+
"250 AUTH PLAIN\r\n"+
"250 Sender test ok\r\n"+
"250 Recipient test ok\r\n"+
"354 End data with <CR><LF>.<CR><LF>\r\n"+
"250 Ok: queued as "+storedMessage.ID+"\r\n",
)
})
}

func TestAcceptTLSUpgrade(t *testing.T) {
Convey("acceptMessage should be called", t, func() {
mbuf1 := "STARTTLS\r\n"
mbuf2 := "EHLO localhost\r\n" +
"MAIL FROM:<test>\r\n" +
"RCPT TO:<test>\r\n" +
"DATA\r\n" +
"Hi.\r\n" +
".\r\n" +
"QUIT\n"

frw1, obuf1 := getBuffer(mbuf1)
frw2, obuf2 := getBuffer(mbuf2)
mChan := make(chan *data.Message)
var wg sync.WaitGroup
wg.Add(1)
handlerCalled := false
var storedMessage *data.Message
go func() {
handlerCalled = true
<-mChan
//FIXME breaks some tests (in drone.io)
//m := <-mChan
//So(m, ShouldNotBeNil)
storedMessage = <-mChan
wg.Done()
}()
Accept("1.1.1.1:11111", frw, storage.CreateInMemory(), mChan, "localhost", nil)

tlsWasUpgraded := false
tlsUpgrade := func() io.ReadWriteCloser {
tlsWasUpgraded = true
return frw2
}

Accept("1.1.1.1:11111", frw1, tlsUpgrade, storage.CreateInMemory(), mChan, "localhost", nil)
wg.Wait()

So(handlerCalled, ShouldBeTrue)
So(tlsWasUpgraded, ShouldBeTrue)

So(storedMessage, ShouldNotBeNil)
So(string(*obuf1), ShouldEqual,
"220 localhost ESMTP MailHog\r\n"+
"220 Ready to start TLS\r\n",
)
So(string(*obuf2), ShouldEqual,
"250-Hello localhost\r\n"+
"250-PIPELINING\r\n"+
"250 AUTH PLAIN\r\n"+
"250 Sender test ok\r\n"+
"250 Recipient test ok\r\n"+
"354 End data with <CR><LF>.<CR><LF>\r\n"+
"250 Ok: queued as "+storedMessage.ID+"\r\n",
)
})
}

func getBuffer(input string) (io.ReadWriteCloser, *[]byte) {
var rbuf []byte
frw := &fakeRw{
_read: func(p []byte) (n int, err error) {
if len(p) >= len(input) {
ba := []byte(input)
input = ""
for i, b := range ba {
p[i] = b
}
return len(ba), nil
}

ba := []byte(input[0:len(p)])
input = input[len(p):]
for i, b := range ba {
p[i] = b
}
return len(ba), nil
},
_write: func(p []byte) (n int, err error) {
rbuf = append(rbuf, p...)
return len(p), nil
},
_close: func() error {
return nil
},
}
return frw, &rbuf
}

func TestValidateAuthentication(t *testing.T) {
Convey("validateAuthentication is always successful", t, func() {
c := &Session{}
Expand Down
9 changes: 9 additions & 0 deletions smtp/smtp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smtp

import (
"crypto/tls"
"io"
"log"
"net"
Expand Down Expand Up @@ -31,9 +32,17 @@ func Listen(cfg *config.Config, exitCh chan int) *net.TCPListener {
}
}

var tlsUpgrader func() io.ReadWriteCloser
if cfg.TLSConfig != nil {
tlsUpgrader = func() io.ReadWriteCloser {
return io.ReadWriteCloser(tls.Server(conn, cfg.TLSConfig))
}
}

go Accept(
conn.(*net.TCPConn).RemoteAddr().String(),
io.ReadWriteCloser(conn),
tlsUpgrader,
cfg.Storage,
cfg.MessageChan,
cfg.Hostname,
Expand Down