diff --git a/auth.go b/auth.go deleted file mode 100644 index d28b83a..0000000 --- a/auth.go +++ /dev/null @@ -1,49 +0,0 @@ -package gomail - -import ( - "bytes" - "errors" - "fmt" - "net/smtp" -) - -// loginAuth is an smtp.Auth that implements the LOGIN authentication mechanism. -type loginAuth struct { - username string - password string - host string -} - -func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { - if !server.TLS { - advertised := false - for _, mechanism := range server.Auth { - if mechanism == "LOGIN" { - advertised = true - break - } - } - if !advertised { - return "", nil, errors.New("gomail: unencrypted connection") - } - } - if server.Name != a.host { - return "", nil, errors.New("gomail: wrong host name") - } - return "LOGIN", nil, nil -} - -func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { - if !more { - return nil, nil - } - - switch { - case bytes.Equal(fromServer, []byte("Username:")): - return []byte(a.username), nil - case bytes.Equal(fromServer, []byte("Password:")): - return []byte(a.password), nil - default: - return nil, fmt.Errorf("gomail: unexpected server challenge: %s", fromServer) - } -} diff --git a/auth_test.go b/auth_test.go deleted file mode 100644 index 428ef34..0000000 --- a/auth_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package gomail - -import ( - "net/smtp" - "testing" -) - -const ( - testUser = "user" - testPwd = "pwd" - testHost = "smtp.example.com" -) - -type authTest struct { - auths []string - challenges []string - tls bool - wantData []string - wantError bool -} - -func TestNoAdvertisement(t *testing.T) { - testLoginAuth(t, &authTest{ - auths: []string{}, - tls: false, - wantError: true, - }) -} - -func TestNoAdvertisementTLS(t *testing.T) { - testLoginAuth(t, &authTest{ - auths: []string{}, - challenges: []string{"Username:", "Password:"}, - tls: true, - wantData: []string{"", testUser, testPwd}, - }) -} - -func TestLogin(t *testing.T) { - testLoginAuth(t, &authTest{ - auths: []string{"PLAIN", "LOGIN"}, - challenges: []string{"Username:", "Password:"}, - tls: false, - wantData: []string{"", testUser, testPwd}, - }) -} - -func TestLoginTLS(t *testing.T) { - testLoginAuth(t, &authTest{ - auths: []string{"LOGIN"}, - challenges: []string{"Username:", "Password:"}, - tls: true, - wantData: []string{"", testUser, testPwd}, - }) -} - -func testLoginAuth(t *testing.T, test *authTest) { - auth := &loginAuth{ - username: testUser, - password: testPwd, - host: testHost, - } - server := &smtp.ServerInfo{ - Name: testHost, - TLS: test.tls, - Auth: test.auths, - } - proto, toServer, err := auth.Start(server) - if err != nil && !test.wantError { - t.Fatalf("loginAuth.Start(): %v", err) - } - if err != nil && test.wantError { - return - } - if proto != "LOGIN" { - t.Errorf("invalid protocol, got %q, want LOGIN", proto) - } - - i := 0 - got := string(toServer) - if got != test.wantData[i] { - t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) - } - - for _, challenge := range test.challenges { - i++ - if i >= len(test.wantData) { - t.Fatalf("unexpected challenge: %q", challenge) - } - - toServer, err = auth.Next([]byte(challenge), true) - if err != nil { - t.Fatalf("loginAuth.Auth(): %v", err) - } - got = string(toServer) - if got != test.wantData[i] { - t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) - } - } -} diff --git a/go.mod b/go.mod index 4c2f62d..63581c3 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,8 @@ module github.com/gophish/gomail go 1.13 -require gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc +require ( + github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 + github.com/emersion/go-smtp v0.14.0 + gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc +) diff --git a/go.sum b/go.sum index 4ca8b8e..45067bd 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ= +github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= +github.com/emersion/go-smtp v0.14.0 h1:RYW203p+EcPjL8Z/ZpT9lZ6iOc8MG1MQzEx1UKEkXlA= +github.com/emersion/go-smtp v0.14.0/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= diff --git a/smtp.go b/smtp.go index 18ac53a..264b1d7 100644 --- a/smtp.go +++ b/smtp.go @@ -5,9 +5,11 @@ import ( "fmt" "io" "net" - "net/smtp" "strings" "time" + + "github.com/emersion/go-sasl" + "github.com/emersion/go-smtp" ) var defaultDialer = &net.Dialer{ @@ -27,7 +29,7 @@ type Dialer struct { Password string // Auth represents the authentication mechanism used to authenticate to the // SMTP server. - Auth smtp.Auth + Auth sasl.Client // SSL defines whether an SSL connection is used. It should be false in // most cases since the authentication mechanism should use the STARTTLS // extension instead. @@ -38,6 +40,8 @@ type Dialer struct { // LocalName is the hostname sent to the SMTP server with the HELO command. // By default, "localhost" is sent. LocalName string + // Options + MailOptions *smtp.MailOptions dialer netDialer } @@ -112,17 +116,11 @@ func (d *Dialer) Dial() (SendCloser, error) { if d.Auth == nil && d.Username != "" { if ok, auths := c.Extension("AUTH"); ok { - if strings.Contains(auths, "CRAM-MD5") { - d.Auth = smtp.CRAMMD5Auth(d.Username, d.Password) - } else if strings.Contains(auths, "LOGIN") && + if strings.Contains(auths, "LOGIN") && !strings.Contains(auths, "PLAIN") { - d.Auth = &loginAuth{ - username: d.Username, - password: d.Password, - host: d.Host, - } + d.Auth = sasl.NewLoginClient(d.Username, d.Password) } else { - d.Auth = smtp.PlainAuth("", d.Username, d.Password, d.Host) + d.Auth = sasl.NewPlainClient("", d.Username, d.Password) } } } @@ -166,7 +164,7 @@ type smtpSender struct { } func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error { - if err := c.Mail(from); err != nil { + if err := c.Mail(from, c.d.MailOptions); err != nil { if err == io.EOF { // This is probably due to a timeout, so reconnect and try again. sc, derr := c.d.Dial() @@ -223,8 +221,8 @@ type smtpClient interface { Hello(string) error Extension(string) (bool, string) StartTLS(*tls.Config) error - Auth(smtp.Auth) error - Mail(string) error + Auth(sasl.Client) error + Mail(string, *smtp.MailOptions) error Rcpt(string) error Reset() error Data() (io.WriteCloser, error) diff --git a/smtp_test.go b/smtp_test.go index 8c43d1c..f9807c3 100644 --- a/smtp_test.go +++ b/smtp_test.go @@ -5,9 +5,18 @@ import ( "crypto/tls" "io" "net" - "net/smtp" "reflect" "testing" + + "github.com/emersion/go-sasl" + "github.com/emersion/go-smtp" +) + + +const ( + testUser = "user" + testPwd = "pwd" + testHost = "smtp.example.com" ) const ( @@ -19,7 +28,7 @@ var ( testConn = &net.TCPConn{} testTLSConn = &tls.Conn{} testConfig = &tls.Config{InsecureSkipVerify: true} - testAuth = smtp.PlainAuth("", testUser, testPwd, testHost) + testAuth = sasl.NewPlainClient("", testUser, testPwd) ) type mockNetDialer struct { @@ -183,7 +192,7 @@ func (c *mockClient) StartTLS(config *tls.Config) error { return nil } -func (c *mockClient) Auth(a smtp.Auth) error { +func (c *mockClient) Auth(a sasl.Client) error { if !reflect.DeepEqual(a, testAuth) { c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth) } @@ -191,7 +200,7 @@ func (c *mockClient) Auth(a smtp.Auth) error { return nil } -func (c *mockClient) Mail(from string) error { +func (c *mockClient) Mail(from string, opts *smtp.MailOptions) error { c.do("Mail " + from) if c.timeout { c.timeout = false