From dd4d482a7011bd7dcc2dbcc3eadd685a589e93e8 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 20 May 2026 16:40:18 +0800 Subject: [PATCH] fix email reminder dedup keys --- .../service/notification_email_service.go | 66 +++-- .../notification_email_service_test.go | 228 ++++++++++++++++++ 2 files changed, 278 insertions(+), 16 deletions(-) diff --git a/backend/internal/service/notification_email_service.go b/backend/internal/service/notification_email_service.go index 82078283..d21363b6 100644 --- a/backend/internal/service/notification_email_service.go +++ b/backend/internal/service/notification_email_service.go @@ -382,7 +382,7 @@ func (s *NotificationEmailService) Send(ctx context.Context, input NotificationE deliveryKey := notificationEmailDeliveryKey(normalizedEvent, input.SourceType, input.SourceID, recipient, input.ReminderKey) if deliveryKey != "" { - sent, err := s.deliveryExists(ctx, deliveryKey) + sent, err := s.deliveryExists(ctx, deliveryKey, legacyNotificationEmailDeliveryKey(normalizedEvent, input.SourceType, input.SourceID, recipient, input.ReminderKey)) if err != nil { return err } @@ -398,7 +398,9 @@ func (s *NotificationEmailService) Send(ctx context.Context, input NotificationE return notificationEmailDeliveryErr(err) } if deliveryKey != "" { - _ = s.settingRepo.Set(ctx, deliveryKey, time.Now().UTC().Format(time.RFC3339Nano)) + if err := s.settingRepo.Set(ctx, deliveryKey, time.Now().UTC().Format(time.RFC3339Nano)); err != nil { + return err + } } return nil } @@ -441,14 +443,19 @@ func (s *NotificationEmailService) IsUnsubscribed(ctx context.Context, email, ev if !info.Optional { return false, nil } - value, err := s.settingRepo.GetValue(ctx, notificationEmailPreferenceKey(normalizedEvent, email)) - if err != nil { - if errors.Is(err, ErrSettingNotFound) { - return false, nil + for _, key := range []string{notificationEmailPreferenceKey(normalizedEvent, email), legacyNotificationEmailPreferenceKey(normalizedEvent, email)} { + if strings.TrimSpace(key) == "" { + continue + } + value, err := s.settingRepo.GetValue(ctx, key) + if err == nil { + return strings.EqualFold(strings.TrimSpace(value), "unsubscribed"), nil + } + if !errors.Is(err, ErrSettingNotFound) { + return false, err } - return false, err } - return strings.EqualFold(strings.TrimSpace(value), "unsubscribed"), nil + return false, nil } func (s *NotificationEmailService) Unsubscribe(ctx context.Context, token string) (NotificationEmailUnsubscribeResult, error) { @@ -610,15 +617,20 @@ func (s *NotificationEmailService) unsubscribeSecret(ctx context.Context) (strin return secret, nil } -func (s *NotificationEmailService) deliveryExists(ctx context.Context, key string) (bool, error) { - _, err := s.settingRepo.GetValue(ctx, key) - if err == nil { - return true, nil +func (s *NotificationEmailService) deliveryExists(ctx context.Context, keys ...string) (bool, error) { + for _, key := range keys { + if strings.TrimSpace(key) == "" { + continue + } + _, err := s.settingRepo.GetValue(ctx, key) + if err == nil { + return true, nil + } + if !errors.Is(err, ErrSettingNotFound) { + return false, err + } } - if errors.Is(err, ErrSettingNotFound) { - return false, nil - } - return false, err + return false, nil } func validateNotificationEmailTemplate(event, subject, htmlBody string) error { @@ -749,10 +761,32 @@ func notificationEmailTemplateKey(event, locale string) string { } func notificationEmailPreferenceKey(event, email string) string { + if strings.TrimSpace(event) == "" || strings.TrimSpace(email) == "" { + return "" + } + identity := strings.TrimSpace(event) + "\x00" + strings.ToLower(strings.TrimSpace(email)) + return notificationEmailPreferenceKeyPrefix + "v2:" + notificationEmailHash(identity) +} + +func legacyNotificationEmailPreferenceKey(event, email string) string { return notificationEmailPreferenceKeyPrefix + event + ":" + notificationEmailHash(email) } func notificationEmailDeliveryKey(event, sourceType, sourceID, recipient, reminderKey string) string { + if strings.TrimSpace(sourceType) == "" || strings.TrimSpace(sourceID) == "" || strings.TrimSpace(recipient) == "" { + return "" + } + identity := strings.Join([]string{ + strings.ToLower(strings.TrimSpace(event)), + safeNotificationEmailKeyPart(sourceType), + safeNotificationEmailKeyPart(sourceID), + strings.ToLower(strings.TrimSpace(recipient)), + safeNotificationEmailKeyPart(reminderKey), + }, "\x00") + return notificationEmailDeliveryKeyPrefix + "v2:" + notificationEmailHash(identity) +} + +func legacyNotificationEmailDeliveryKey(event, sourceType, sourceID, recipient, reminderKey string) string { if strings.TrimSpace(sourceType) == "" || strings.TrimSpace(sourceID) == "" || strings.TrimSpace(recipient) == "" { return "" } diff --git a/backend/internal/service/notification_email_service_test.go b/backend/internal/service/notification_email_service_test.go index f375ba7e..38987f6f 100644 --- a/backend/internal/service/notification_email_service_test.go +++ b/backend/internal/service/notification_email_service_test.go @@ -1,10 +1,13 @@ package service import ( + "bufio" "context" "errors" + "net" "strings" "sync" + "sync/atomic" "testing" "github.com/stretchr/testify/require" @@ -262,6 +265,114 @@ func TestNotificationEmailLocaleMemoryNormalizesAcceptLanguage(t *testing.T) { require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 0, "user@example.com")) } +func TestNotificationEmailDeliveryKeyUsesShortStableHash(t *testing.T) { + key := notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "User@Example.com", + "7d", + ) + require.NotEmpty(t, key) + require.LessOrEqual(t, len(key), 100) + require.True(t, strings.HasPrefix(key, notificationEmailDeliveryKeyPrefix+"v2:")) + require.Equal(t, key, notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "7d", + )) + require.NotEqual(t, key, notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "3d", + )) + + legacyKey := legacyNotificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "7d", + ) + require.Greater(t, len(legacyKey), 100) +} + +func TestNotificationEmailPreferenceKeyUsesShortStableHashAndReadsLegacyKey(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + + key := notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "User@Example.com") + require.NotEmpty(t, key) + require.LessOrEqual(t, len(key), 100) + require.True(t, strings.HasPrefix(key, notificationEmailPreferenceKeyPrefix+"v2:")) + require.Equal(t, key, notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com")) + + legacyKey := legacyNotificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com") + require.Greater(t, len(legacyKey), 100) + require.NoError(t, repo.Set(ctx, legacyKey, "unsubscribed")) + + unsubscribed, err := svc.IsUnsubscribed(ctx, "User@Example.com", NotificationEmailEventSubscriptionExpiryReminder) + require.NoError(t, err) + require.True(t, unsubscribed) +} + +func TestNotificationEmailSendDeduplicatesSubscriptionExpiryReminder(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + smtpServer := startNotificationEmailTestSMTPServer(t) + require.NoError(t, repo.SetMultiple(ctx, smtpServer.settings())) + + emailSvc := NewEmailService(repo, nil) + svc := NewNotificationEmailService(repo, emailSvc) + input := NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: "User@Example.com", + RecipientName: "User", + UserID: 42, + SourceType: "user_subscription", + SourceID: "1234567890", + ReminderKey: "7d", + Variables: map[string]string{ + "subscription_group": "Codex", + "expiry_time": "2026-05-27 12:00", + "days_remaining": "7", + }, + } + + require.NoError(t, svc.Send(ctx, input)) + require.Equal(t, int64(1), smtpServer.messageCount()) + + key := notificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey) + require.LessOrEqual(t, len(key), 100) + _, err := repo.GetValue(ctx, key) + require.NoError(t, err) + + require.NoError(t, svc.Send(ctx, input)) + require.Equal(t, int64(1), smtpServer.messageCount()) +} + +func TestNotificationEmailSendRespectsLegacyDeliveryKey(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + input := NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: "user@example.com", + SourceType: "user_subscription", + SourceID: "1234567890", + ReminderKey: "7d", + } + legacyKey := legacyNotificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey) + require.NoError(t, repo.Set(ctx, legacyKey, "sent")) + + require.NoError(t, svc.Send(ctx, input)) +} + type notificationEmailMemorySettingRepo struct { mu sync.RWMutex values map[string]string @@ -341,3 +452,120 @@ func TestNotificationEmailMemorySettingRepoSatisfiesInterface(t *testing.T) { var _ SettingRepository = (*notificationEmailMemorySettingRepo)(nil) require.False(t, strings.Contains(notificationEmailPreferenceKey(NotificationEmailEventBalanceLow, "User@Example.com"), "User@Example.com")) } + +type notificationEmailTestSMTPServer struct { + listener net.Listener + wg sync.WaitGroup + messages atomic.Int64 +} + +func startNotificationEmailTestSMTPServer(t *testing.T) *notificationEmailTestSMTPServer { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + server := ¬ificationEmailTestSMTPServer{listener: listener} + server.wg.Add(1) + go server.serve() + t.Cleanup(server.close) + return server +} + +func (s *notificationEmailTestSMTPServer) settings() map[string]string { + host, port, _ := net.SplitHostPort(s.listener.Addr().String()) + return map[string]string{ + SettingKeySMTPHost: host, + SettingKeySMTPPort: port, + SettingKeySMTPUsername: "user", + SettingKeySMTPPassword: "password", + SettingKeySMTPFrom: "noreply@example.com", + SettingKeySMTPFromName: "Sub2API", + SettingKeySMTPUseTLS: "false", + } +} + +func (s *notificationEmailTestSMTPServer) messageCount() int64 { + return s.messages.Load() +} + +func (s *notificationEmailTestSMTPServer) close() { + _ = s.listener.Close() + s.wg.Wait() +} + +func (s *notificationEmailTestSMTPServer) serve() { + defer s.wg.Done() + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + s.handleConn(conn) + } +} + +func (s *notificationEmailTestSMTPServer) handleConn(conn net.Conn) { + defer func() { _ = conn.Close() }() + rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + writeLine := func(line string) bool { + if _, err := rw.WriteString(line + "\r\n"); err != nil { + return false + } + return rw.Flush() == nil + } + if !writeLine("220 localhost ESMTP") { + return + } + for { + line, err := rw.ReadString('\n') + if err != nil { + return + } + cmd := strings.ToUpper(strings.TrimRight(line, "\r\n")) + switch { + case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"): + if _, err := rw.WriteString("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil { + return + } + if err := rw.Flush(); err != nil { + return + } + case strings.HasPrefix(cmd, "AUTH"): + if !writeLine("235 2.7.0 Authentication successful") { + return + } + case strings.HasPrefix(cmd, "MAIL FROM:"): + if !writeLine("250 2.1.0 OK") { + return + } + case strings.HasPrefix(cmd, "RCPT TO:"): + if !writeLine("250 2.1.5 OK") { + return + } + case strings.HasPrefix(cmd, "DATA"): + if !writeLine("354 End data with .") { + return + } + for { + dataLine, err := rw.ReadString('\n') + if err != nil { + return + } + if strings.TrimRight(dataLine, "\r\n") == "." { + break + } + } + s.messages.Add(1) + if !writeLine("250 2.0.0 OK") { + return + } + case strings.HasPrefix(cmd, "QUIT"): + _ = writeLine("221 2.0.0 Bye") + return + default: + if !writeLine("250 OK") { + return + } + } + } +}