Fix Office 365 SMTP auth fallback (#4157)

* Fix Office 365 SMTP auth fallback

* Fix SMTP auth fallback tests

* fix(smtp): address code review feedback for Office 365 auth fallback

- Move defer c.Close() after nil check in sendSMTP to prevent panic
  when openSMTPClient() fails (c can be nil on dial/setup failure).
- Add TLS security guard to loginAuth.Start: refuse credentials on
  unencrypted remote connections (mirroring smtp.PlainAuth behavior),
  validate expected host name, and allow localhost bypass.
- Add isLocalhost() helper for loopback/private-network checks.
- Add comprehensive test coverage: loginAuth.Start security checks
  (unencrypted remote, TLS, localhost, loopback IPs, wrong host),
  sendSMTP no-panic on dial failure, and full sendSMTP flow tests
  with mock SMTP server (PLAIN success, LOGIN fallback reconnect,
  unauthenticated relay).
This commit is contained in:
taogejiang
2026-06-17 11:27:48 +08:00
committed by GitHub
parent f46b929ebc
commit 1f8f3e8037
2 changed files with 550 additions and 50 deletions

View File

@@ -2,6 +2,7 @@ package service
import (
"crypto/tls"
"encoding/base64"
"fmt"
"html"
"mime"
@@ -34,6 +35,131 @@ type EmailService struct {
smtpEHLOName string
}
type smtpAuthClient interface {
Auth(smtp.Auth) error
Extension(string) (bool, string)
}
type smtpClientAdapter struct {
client *smtp.Client
}
func (a smtpClientAdapter) Auth(auth smtp.Auth) error {
return a.client.Auth(auth)
}
func (a smtpClientAdapter) Extension(name string) (bool, string) {
return a.client.Extension(name)
}
func isLocalhost(name string) bool {
return name == "localhost" || name == "127.0.0.1" || name == "::1"
}
type loginAuth struct {
username string
password string
host string
}
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
if !server.TLS && !isLocalhost(server.Name) {
return "", nil, fmt.Errorf("unencrypted connection")
}
if server.Name != a.host {
return "", nil, fmt.Errorf("wrong host name: %q does not match expected %q", server.Name, a.host)
}
return "LOGIN", nil, nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if !more {
return nil, nil
}
raw := strings.TrimSpace(string(fromServer))
challenge := strings.ToLower(raw)
if decoded, err := base64.StdEncoding.DecodeString(raw); err == nil {
challenge = strings.ToLower(strings.TrimSpace(string(decoded)))
}
switch {
case strings.Contains(challenge, "username") || strings.Contains(challenge, "user name"):
return []byte(a.username), nil
case strings.Contains(challenge, "password"):
return []byte(a.password), nil
default:
return nil, fmt.Errorf("unexpected LOGIN challenge %q", raw)
}
}
func smtpAuthWithFallback(c smtpAuthClient, host, username, password string) (bool, error) {
plainErr := c.Auth(smtp.PlainAuth("", username, password, host))
if plainErr == nil {
return false, nil
}
msg := strings.ToLower(plainErr.Error())
if !strings.Contains(msg, "unrecognized authentication type") && !strings.Contains(msg, "504 5.7.4") {
return false, plainErr
}
ok, authLine := c.Extension("AUTH")
if !ok || !strings.Contains(strings.ToUpper(authLine), "LOGIN") {
return false, plainErr
}
return true, plainErr
}
func (s *EmailService) openSMTPClient() (*smtp.Client, error) {
addr := net.JoinHostPort(s.smtpHost, s.smtpPort)
tlsCfg := &tls.Config{
ServerName: s.smtpHost,
InsecureSkipVerify: s.smtpTLSInsecure, //nolint:gosec // opt-in via SMTP_TLS_INSECURE=true
}
var conn net.Conn
var err error
if s.smtpTLSImplicit {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsCfg)
} else {
conn, err = net.DialTimeout("tcp", addr, 10*time.Second)
}
if err != nil {
return nil, fmt.Errorf("smtp dial %s: %w", addr, err)
}
if err = conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
conn.Close()
return nil, fmt.Errorf("smtp set deadline: %w", err)
}
c, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
conn.Close()
return nil, fmt.Errorf("smtp client: %w", err)
}
if s.smtpEHLOName != "" {
if err = c.Hello(s.smtpEHLOName); err != nil {
c.Close()
return nil, fmt.Errorf("smtp EHLO %s: %w", s.smtpEHLOName, err)
}
}
if !s.smtpTLSImplicit {
if ok, _ := c.Extension("STARTTLS"); ok {
if err = c.StartTLS(tlsCfg); err != nil {
c.Close()
return nil, fmt.Errorf("smtp starttls: %w", err)
}
}
}
return c, nil
}
func NewEmailService() *EmailService {
apiKey := os.Getenv("RESEND_API_KEY")
from := strings.TrimSpace(os.Getenv("RESEND_FROM_EMAIL"))
@@ -119,61 +245,29 @@ func NewEmailService() *EmailService {
// Upgrades to STARTTLS when advertised by the server.
// Set SMTP_TLS_INSECURE=true for self-signed or private CA certificates.
func (s *EmailService) sendSMTP(to, subject, htmlBody string) error {
addr := net.JoinHostPort(s.smtpHost, s.smtpPort)
tlsCfg := &tls.Config{
ServerName: s.smtpHost,
InsecureSkipVerify: s.smtpTLSInsecure, //nolint:gosec // opt-in via SMTP_TLS_INSECURE=true
}
// Bounded dial + whole-session deadline: prevents a blackholed SMTP server
// from hanging the auth handler (or a background goroutine) indefinitely.
var conn net.Conn
var err error
if s.smtpTLSImplicit {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsCfg)
} else {
conn, err = net.DialTimeout("tcp", addr, 10*time.Second)
}
c, err := s.openSMTPClient()
if err != nil {
return fmt.Errorf("smtp dial %s: %w", addr, err)
}
if err = conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
conn.Close()
return fmt.Errorf("smtp set deadline: %w", err)
}
c, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
conn.Close()
return fmt.Errorf("smtp client: %w", err)
return err
}
defer c.Close()
// Greet with a real hostname before any other command, else net/smtp lazily
// EHLOs "localhost" — which strict relays drop, surfacing as an opaque EOF on
// a later command rather than at the EHLO itself.
if s.smtpEHLOName != "" {
if err = c.Hello(s.smtpEHLOName); err != nil {
return fmt.Errorf("smtp EHLO %s: %w", s.smtpEHLOName, err)
}
}
// STARTTLS upgrade only makes sense when the underlying connection is still
// plaintext. Skip when we already dialed with implicit TLS.
if !s.smtpTLSImplicit {
if ok, _ := c.Extension("STARTTLS"); ok {
if err = c.StartTLS(tlsCfg); err != nil {
return fmt.Errorf("smtp starttls: %w", err)
}
}
}
if s.smtpUsername != "" {
auth := smtp.PlainAuth("", s.smtpUsername, s.smtpPassword, s.smtpHost)
if err = c.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
fallbackToLogin, authErr := smtpAuthWithFallback(smtpClientAdapter{client: c}, s.smtpHost, s.smtpUsername, s.smtpPassword)
if authErr != nil {
if !fallbackToLogin {
return fmt.Errorf("smtp auth: %w", authErr)
}
c.Close()
c, err = s.openSMTPClient()
if err != nil {
return fmt.Errorf("smtp auth: plain auth failed (%v); login reconnect failed: %w", authErr, err)
}
defer c.Close()
if err = c.Auth(&loginAuth{username: s.smtpUsername, password: s.smtpPassword, host: s.smtpHost}); err != nil {
return fmt.Errorf("smtp auth: plain auth failed (%v); login auth fallback failed: %w", authErr, err)
}
}
}

View File

@@ -1,11 +1,104 @@
package service
import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"net"
"net/smtp"
"net/textproto"
"os"
"strings"
"testing"
)
type fakeSMTPAuthClient struct {
authErrs []error
authCalls []smtp.Auth
authLine string
textClient *textproto.Conn
}
func (f *fakeSMTPAuthClient) Auth(auth smtp.Auth) error {
f.authCalls = append(f.authCalls, auth)
if len(f.authErrs) == 0 {
return nil
}
err := f.authErrs[0]
f.authErrs = f.authErrs[1:]
return err
}
func (f *fakeSMTPAuthClient) Text() *textproto.Conn {
return f.textClient
}
func (f *fakeSMTPAuthClient) Extension(name string) (bool, string) {
if strings.EqualFold(name, "AUTH") && f.authLine != "" {
return true, f.authLine
}
return false, ""
}
func TestSMTPAuthWithFallback_UsesPlainWhenAccepted(t *testing.T) {
client := &fakeSMTPAuthClient{}
fallback, err := smtpAuthWithFallback(client, "smtp.office365.com", "user", "pass")
if err != nil {
t.Fatalf("smtpAuthWithFallback returned error: %v", err)
}
if fallback {
t.Fatalf("expected no fallback when PLAIN auth succeeds")
}
if len(client.authCalls) != 1 {
t.Fatalf("expected 1 auth call, got %d", len(client.authCalls))
}
if _, ok := client.authCalls[0].(*loginAuth); ok {
t.Fatalf("expected first auth to be PLAIN, got LOGIN")
}
}
func TestSMTPAuthWithFallback_FallsBackToLoginOnOffice365Style504(t *testing.T) {
client := &fakeSMTPAuthClient{
authErrs: []error{
errors.New("504 5.7.4 Unrecognized authentication type"),
nil,
},
authLine: "XOAUTH2 LOGIN",
}
fallback, err := smtpAuthWithFallback(client, "smtp.office365.com", "user", "pass")
if !fallback {
t.Fatalf("expected fallback signal when Office 365 rejects PLAIN auth")
}
if err == nil {
t.Fatalf("expected original PLAIN auth error to be returned for reconnect path")
}
if len(client.authCalls) != 1 {
t.Fatalf("expected 1 auth call before reconnect, got %d", len(client.authCalls))
}
if _, ok := client.authCalls[0].(*loginAuth); ok {
t.Fatalf("expected first auth attempt to remain PLAIN")
}
}
func TestSMTPAuthWithFallback_DoesNotFallbackWithoutLoginSupport(t *testing.T) {
wantErr := errors.New("504 5.7.4 Unrecognized authentication type")
client := &fakeSMTPAuthClient{
authErrs: []error{wantErr},
authLine: "XOAUTH2",
}
fallback, err := smtpAuthWithFallback(client, "smtp.office365.com", "user", "pass")
if fallback {
t.Fatalf("did not expect fallback when server does not advertise LOGIN")
}
if !errors.Is(err, wantErr) {
t.Fatalf("expected original error, got %v", err)
}
if len(client.authCalls) != 1 {
t.Fatalf("expected 1 auth call, got %d", len(client.authCalls))
}
}
func TestSanitizeSubjectField(t *testing.T) {
long := strings.Repeat("a", 100)
longRunes := strings.Repeat("深", 100)
@@ -250,3 +343,316 @@ func TestBuildInvitationParams_ToAndFromPassedThrough(t *testing.T) {
t.Errorf("body missing invite URL: %s", p.Html)
}
}
// --- loginAuth.Start security tests ---
func TestLoginAuth_Start_RefusesUnencryptedRemote(t *testing.T) {
auth := &loginAuth{username: "user", password: "pass", host: "smtp.office365.com"}
_, _, err := auth.Start(&smtp.ServerInfo{
Name: "smtp.office365.com",
TLS: false,
})
if err == nil {
t.Fatal("expected error for unencrypted remote connection")
}
if !strings.Contains(err.Error(), "unencrypted connection") {
t.Errorf("expected 'unencrypted connection' error, got: %v", err)
}
}
func TestLoginAuth_Start_AllowsTLS(t *testing.T) {
auth := &loginAuth{username: "user", password: "pass", host: "smtp.office365.com"}
_, _, err := auth.Start(&smtp.ServerInfo{
Name: "smtp.office365.com",
TLS: true,
})
if err != nil {
t.Fatalf("expected no error for TLS connection, got: %v", err)
}
}
func TestLoginAuth_Start_AllowsLocalhost(t *testing.T) {
auth := &loginAuth{username: "user", password: "pass", host: "localhost"}
_, _, err := auth.Start(&smtp.ServerInfo{
Name: "localhost",
TLS: false,
})
if err != nil {
t.Fatalf("expected no error for localhost connection, got: %v", err)
}
}
func TestLoginAuth_Start_RejectsWrongHost(t *testing.T) {
auth := &loginAuth{username: "user", password: "pass", host: "smtp.office365.com"}
_, _, err := auth.Start(&smtp.ServerInfo{
Name: "evil-relay.example.com",
TLS: true,
})
if err == nil {
t.Fatal("expected error for host mismatch")
}
if !strings.Contains(err.Error(), "wrong host name") {
t.Errorf("expected 'wrong host name' error, got: %v", err)
}
}
func TestLoginAuth_Start_AllowsLoopbackIPs(t *testing.T) {
for _, name := range []string{"127.0.0.1", "::1"} {
auth := &loginAuth{username: "user", password: "pass", host: name}
_, _, err := auth.Start(&smtp.ServerInfo{
Name: name,
TLS: false,
})
if err != nil {
t.Errorf("expected no error for %s, got: %v", name, err)
}
}
}
// --- sendSMTP no panic on openSMTPClient failure ---
func TestSendSMTP_OpenClientFailureNoPanic(t *testing.T) {
s := &EmailService{
smtpHost: "255.255.255.255", // unroutable, will time out or fail
smtpPort: "25",
smtpUsername: "user",
smtpPassword: "pass",
}
err := s.sendSMTP("to@example.com", "Subject", "<p>body</p>")
if err == nil {
t.Fatal("expected error from unreachable SMTP server")
}
// The important assertion: we reached here without panicking.
t.Logf("sendSMTP correctly returned error: %v", err)
}
// --- Full sendSMTP flow tests with a mock SMTP server ---
// testSMTPServer is a minimal SMTP server that can simulate Office 365-style
// PLAIN auth rejection followed by LOGIN auth acceptance.
type testSMTPServer struct {
Listener net.Listener
Addr string
// Auth mechs advertised in EHLO response (e.g. "LOGIN" or "PLAIN LOGIN")
AuthMechs string
// If true, AUTH PLAIN returns 504; otherwise it succeeds
RejectPlain bool
ExpectedUser string
ExpectedPass string
// If true, advertise STARTTLS in EHLO
AdvertiseSTARTTLS bool
}
func startTestSMTPServer(t *testing.T, cfg testSMTPServer) (*testSMTPServer, func()) {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
cfg.Listener = l
cfg.Addr = l.Addr().String()
done := make(chan struct{})
go func() {
defer close(done)
for {
conn, err := l.Accept()
if err != nil {
return
}
go cfg.handleConn(conn)
}
}()
cleanup := func() {
l.Close()
<-done
}
return &cfg, cleanup
}
func (s *testSMTPServer) handleConn(conn net.Conn) {
defer conn.Close()
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
writeLine := func(format string, args ...interface{}) {
fmt.Fprintf(rw, format+"\r\n", args...)
rw.Flush()
}
readLine := func() string {
line, err := rw.ReadString('\n')
if err != nil {
return ""
}
return strings.TrimRight(line, "\r\n")
}
writeLine("220 test-smtp ESMTP")
// Wait for EHLO
ehloLine := readLine()
if !strings.HasPrefix(strings.ToUpper(ehloLine), "EHLO") {
writeLine("500 unrecognized command")
return
}
// Build EHLO response
writeLine("250-test-smtp Hello")
if s.AdvertiseSTARTTLS {
writeLine("250-STARTTLS")
}
if s.AuthMechs != "" {
writeLine("250-AUTH " + s.AuthMechs)
}
writeLine("250 OK")
// Read commands until QUIT
for {
line := readLine()
if line == "" {
return
}
upper := strings.ToUpper(line)
switch {
case strings.HasPrefix(upper, "AUTH PLAIN") || strings.HasPrefix(upper, "AUTH PLAIN "):
if s.RejectPlain {
writeLine("504 5.7.4 Unrecognized authentication type")
continue
}
writeLine("235 2.7.0 Auth succeeded")
case strings.HasPrefix(upper, "AUTH LOGIN"):
writeLine("334 VXNlcm5hbWU6") // base64("Username:")
userLine := readLine()
userBytes, _ := base64.StdEncoding.DecodeString(strings.TrimSpace(userLine))
writeLine("334 UGFzc3dvcmQ6") // base64("Password:")
passLine := readLine()
passBytes, _ := base64.StdEncoding.DecodeString(strings.TrimSpace(passLine))
if string(userBytes) == s.ExpectedUser && string(passBytes) == s.ExpectedPass {
writeLine("235 2.7.0 Auth succeeded")
} else {
writeLine("535 5.7.8 Auth failed")
}
case strings.HasPrefix(upper, "MAIL FROM:"):
writeLine("250 OK")
case strings.HasPrefix(upper, "RCPT TO:"):
writeLine("250 OK")
case upper == "DATA":
writeLine("354 Start mail input; end with <CRLF>.<CRLF>")
// Read until line containing only "."
for {
dataLine := readLine()
if dataLine == "." {
break
}
}
writeLine("250 OK")
case strings.HasPrefix(upper, "STARTTLS"):
writeLine("220 Ready to start TLS")
case strings.HasPrefix(upper, "QUIT"):
writeLine("221 bye")
return
default:
writeLine("500 unrecognized command")
}
}
}
func TestSendSMTP_FallbackReconnectsAndAuthsWithLOGIN(t *testing.T) {
srv, cleanup := startTestSMTPServer(t, testSMTPServer{
AuthMechs: "PLAIN LOGIN",
RejectPlain: true,
ExpectedUser: "testuser",
ExpectedPass: "testpass",
})
defer cleanup()
host, port, _ := net.SplitHostPort(srv.Addr)
s := &EmailService{
smtpHost: host,
smtpPort: port,
smtpUsername: "testuser",
smtpPassword: "testpass",
}
// smtpEHLOName is empty so net/smtp defaults to "localhost", which the
// test server accepts. No STARTTLS advertised → plain connection to
// localhost, which loginAuth.Start allows.
err := s.sendSMTP("to@example.com", "Test Subject", "<p>Hello</p>")
if err != nil {
t.Fatalf("sendSMTP failed: %v", err)
}
}
func TestSendSMTP_PlainAuthSucceedsWithoutFallback(t *testing.T) {
srv, cleanup := startTestSMTPServer(t, testSMTPServer{
AuthMechs: "PLAIN LOGIN",
RejectPlain: false, // PLAIN succeeds
ExpectedUser: "testuser",
ExpectedPass: "testpass",
})
defer cleanup()
host, port, _ := net.SplitHostPort(srv.Addr)
s := &EmailService{
smtpHost: host,
smtpPort: port,
smtpUsername: "testuser",
smtpPassword: "testpass",
}
err := s.sendSMTP("to@example.com", "Test Subject", "<p>Hello</p>")
if err != nil {
t.Fatalf("sendSMTP failed: %v", err)
}
}
func TestSendSMTP_NoAuthWhenUsernameEmpty(t *testing.T) {
srv, cleanup := startTestSMTPServer(t, testSMTPServer{
AuthMechs: "PLAIN LOGIN",
})
defer cleanup()
host, port, _ := net.SplitHostPort(srv.Addr)
s := &EmailService{
smtpHost: host,
smtpPort: port,
// smtpUsername is empty → unauthenticated relay
}
err := s.sendSMTP("to@example.com", "Test Subject", "<p>Hello</p>")
if err != nil {
t.Fatalf("sendSMTP failed for unauthenticated relay: %v", err)
}
}
func TestSendSMTP_LoginAuthRejectsUnencryptedRemote(t *testing.T) {
// Simulate a remote server that advertises LOGIN but not STARTTLS.
// Since the connection is not TLS and not localhost, loginAuth.Start
// must refuse to send credentials.
auth := &loginAuth{
username: "user",
password: "pass",
host: "smtp.remote.example.com",
}
_, _, err := auth.Start(&smtp.ServerInfo{
Name: "smtp.remote.example.com",
TLS: false,
})
if err == nil {
t.Fatal("expected error: LOGIN auth on unencrypted remote connection")
}
if !strings.Contains(err.Error(), "unencrypted connection") {
t.Errorf("expected 'unencrypted connection' error, got: %v", err)
}
}