diff --git a/services/skus/datastore.go b/services/skus/datastore.go index b13585b72..4f1543cb9 100644 --- a/services/skus/datastore.go +++ b/services/skus/datastore.go @@ -112,7 +112,7 @@ type orderStore interface { SetLastPaidAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error SetTrialDays(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID, ndays int64) (*model.Order, error) SetStatus(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, status string) error - GetTimeBounds(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (model.OrderTimeBounds, error) + GetExpiresAtP1M(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (time.Time, error) SetExpiresAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error UpdateMetadata(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, data datastore.Metadata) error AppendMetadata(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, key, val string) error @@ -1435,10 +1435,10 @@ func (pg *Postgres) recordOrderPayment(ctx context.Context, dbi sqlx.ExecerConte } func (pg *Postgres) updateOrderExpiresAt(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID) error { - orderTimeBounds, err := pg.orderRepo.GetTimeBounds(ctx, dbi, orderID) + expiresAt, err := pg.orderRepo.GetExpiresAtP1M(ctx, dbi, orderID) if err != nil { return fmt.Errorf("unable to get order time bounds: %w", err) } - return pg.orderRepo.SetExpiresAt(ctx, dbi, orderID, orderTimeBounds.ExpiresAt()) + return pg.orderRepo.SetExpiresAt(ctx, dbi, orderID, expiresAt) } diff --git a/services/skus/storage/repository/repository.go b/services/skus/storage/repository/repository.go index 4a300341f..0e840a06f 100644 --- a/services/skus/storage/repository/repository.go +++ b/services/skus/storage/repository/repository.go @@ -142,6 +142,22 @@ func (r *Order) GetTimeBounds(ctx context.Context, dbi sqlx.QueryerContext, id u return result, nil } +// GetExpiresAtP1M returns expires_at that is last_paid_at (or now()) plus 1 month. +func (r *Order) GetExpiresAtP1M(ctx context.Context, dbi sqlx.QueryerContext, id uuid.UUID) (time.Time, error) { + const q = `SELECT (SELECT COALESCE(last_paid_at, now()) AS last_paid_at) + interval '1 month' AS expires_at FROM orders WHERE id = $1` + + var result time.Time + if err := sqlx.GetContext(ctx, dbi, &result, q, id); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return time.Time{}, model.ErrOrderNotFound + } + + return time.Time{}, err + } + + return result, nil +} + // SetExpiresAt sets expires_at. func (r *Order) SetExpiresAt(ctx context.Context, dbi sqlx.ExecerContext, id uuid.UUID, when time.Time) error { const q = `UPDATE orders SET updated_at = CURRENT_TIMESTAMP, expires_at = $2 WHERE id = $1` diff --git a/services/skus/storage/repository/repository_test.go b/services/skus/storage/repository/repository_test.go index 62da44ef1..8011b71bf 100644 --- a/services/skus/storage/repository/repository_test.go +++ b/services/skus/storage/repository/repository_test.go @@ -7,6 +7,7 @@ import ( "database/sql" "errors" "testing" + "time" uuid "github.com/satori/go.uuid" should "github.com/stretchr/testify/assert" @@ -360,3 +361,94 @@ func TestOrder_AppendMetadataInt(t *testing.T) { }) } } + +func TestOrder_GetExpiresAtP1M(t *testing.T) { + dbi, err := setupDBI() + must.Equal(t, nil, err) + + defer func() { + _, _ = dbi.Exec("TRUNCATE_TABLE orders;") + }() + + type tcGiven struct { + lastPaidAt time.Time + } + + type tcExpected struct { + expiresAt time.Time + err error + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "no_last_paid", + }, + + { + name: "20230202", + given: tcGiven{ + lastPaidAt: time.Date(2023, time.February, 2, 1, 0, 0, 0, time.UTC), + }, + exp: tcExpected{ + expiresAt: time.Date(2023, time.March, 2, 1, 0, 0, 0, time.UTC), + }, + }, + + { + name: "20230331", + given: tcGiven{ + lastPaidAt: time.Date(2023, time.March, 31, 1, 0, 0, 0, time.UTC), + }, + exp: tcExpected{ + expiresAt: time.Date(2023, time.April, 30, 1, 0, 0, 0, time.UTC), + }, + }, + } + + repo := repository.NewOrder() + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + ctx := context.TODO() + + tx, err := dbi.BeginTxx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted}) + must.Equal(t, nil, err) + + t.Cleanup(func() { _ = tx.Rollback() }) + + order, err := createOrderForTest(ctx, tx, repo) + must.Equal(t, nil, err) + + if !tc.given.lastPaidAt.IsZero() { + err := repo.SetLastPaidAt(ctx, tx, order.ID, tc.given.lastPaidAt) + must.Equal(t, nil, err) + } + + actual, err := repo.GetExpiresAtP1M(ctx, tx, order.ID) + must.Equal(t, nil, err) + + // Handle the special case where last_paid_at was not set. + // The time is generated by the database, so it is non-deterministic. + // The result should not be too far from time.Now()+1 month. + if tc.given.lastPaidAt.IsZero() { + now := time.Now() + future := time.Date(now.Year(), now.Month()+1, now.Day(), now.Hour(), now.Minute(), now.Second(), now.Nanosecond(), now.Location()) + + should.Equal(t, true, future.Sub(actual) < time.Duration(12*time.Hour)) + return + } + + // TODO(pavelb): update local and testing containers to use Go 1.20+. + // Then switch to tc.exp.expiresAt.Compare(actual) == 0. + should.Equal(t, true, tc.exp.expiresAt.Sub(actual) == 0) + }) + } +}