diff --git a/readthedocs/organizations/migrations/0011_add_stripe_subscription_field.py b/readthedocs/organizations/migrations/0011_add_stripe_subscription_field.py new file mode 100644 index 00000000000..49a409c51dd --- /dev/null +++ b/readthedocs/organizations/migrations/0011_add_stripe_subscription_field.py @@ -0,0 +1,40 @@ +# Generated by Django 3.2.16 on 2022-11-21 22:52 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("djstripe", "0010_alter_customer_balance"), + ("organizations", "0010_add_stripe_customer"), + ] + + operations = [ + migrations.AddField( + model_name="historicalorganization", + name="stripe_subscription", + field=models.ForeignKey( + blank=True, + db_constraint=False, + null=True, + on_delete=django.db.models.deletion.DO_NOTHING, + related_name="+", + to="djstripe.subscription", + verbose_name="Stripe subscription", + ), + ), + migrations.AddField( + model_name="organization", + name="stripe_subscription", + field=models.OneToOneField( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="rtd_organization", + to="djstripe.subscription", + verbose_name="Stripe subscription", + ), + ), + ] diff --git a/readthedocs/organizations/models.py b/readthedocs/organizations/models.py index 886d722b564..07f12a4985b 100644 --- a/readthedocs/organizations/models.py +++ b/readthedocs/organizations/models.py @@ -1,4 +1,5 @@ """Organizations models.""" +import structlog from autoslug import AutoSlugField from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType @@ -6,6 +7,7 @@ from django.urls import reverse from django.utils.crypto import salted_hmac from django.utils.translation import gettext_lazy as _ +from djstripe.enums import SubscriptionStatus from readthedocs.core.history import ExtraHistoricalRecords from readthedocs.core.permissions import AdminPermission @@ -16,6 +18,8 @@ from .querysets import OrganizationQuerySet from .utils import send_team_add_email +log = structlog.get_logger(__name__) + class Organization(models.Model): @@ -101,6 +105,14 @@ class Organization(models.Model): null=True, blank=True, ) + stripe_subscription = models.OneToOneField( + "djstripe.Subscription", + verbose_name=_("Stripe subscription"), + on_delete=models.SET_NULL, + related_name="rtd_organization", + null=True, + blank=True, + ) # Managers objects = OrganizationQuerySet.as_manager() @@ -115,15 +127,23 @@ class Meta: def __str__(self): return self.name - @property - def stripe_subscription(self): + def get_or_create_stripe_subscription(self): # TODO: remove this once we don't depend on our Subscription models. from readthedocs.subscriptions.models import Subscription subscription = Subscription.objects.get_or_create_default_subscription(self) if not subscription: # This only happens during development. + log.warning("No default subscription created.") return None + + # Active subscriptions take precedence over non-active subscriptions, + # otherwise we return the must recently created subscription. + active_subscription = self.stripe_customer.subscriptions.filter( + status=SubscriptionStatus.active + ).first() + if active_subscription: + return active_subscription return self.stripe_customer.subscriptions.latest() def get_absolute_url(self): diff --git a/readthedocs/subscriptions/event_handlers.py b/readthedocs/subscriptions/event_handlers.py index b755f50360a..2f1719fae5b 100644 --- a/readthedocs/subscriptions/event_handlers.py +++ b/readthedocs/subscriptions/event_handlers.py @@ -155,6 +155,15 @@ def checkout_completed(event): return stripe_subscription_id = event.data["object"]["subscription"] + stripe_subscription = djstripe.Subscription.objects.filter( + id=stripe_subscription_id + ).first() + if not stripe_subscription: + log.info("Stripe subscription not found.") + return + organization.stripe_subscription = stripe_subscription + organization.save() + _update_subscription_from_stripe( rtd_subscription=organization.subscription, stripe_subscription_id=stripe_subscription_id, diff --git a/readthedocs/subscriptions/tests/test_event_handlers.py b/readthedocs/subscriptions/tests/test_event_handlers.py index 47c027f7477..aec2943e936 100644 --- a/readthedocs/subscriptions/tests/test_event_handlers.py +++ b/readthedocs/subscriptions/tests/test_event_handlers.py @@ -168,11 +168,14 @@ def test_subscription_checkout_completed_event(self): status=SubscriptionStatus.canceled, ) + self.assertIsNone(self.organization.stripe_subscription) event_handlers.checkout_completed(event=event) subscription.refresh_from_db() + self.organization.refresh_from_db() self.assertEqual(subscription.stripe_id, stripe_subscription.id) self.assertEqual(subscription.status, SubscriptionStatus.active) + self.assertEqual(self.organization.stripe_subscription, stripe_subscription) @mock.patch("readthedocs.subscriptions.event_handlers.cancel_stripe_subscription") def test_cancel_trial_subscription_after_trial_has_ended( diff --git a/readthedocs/subscriptions/tests/test_views.py b/readthedocs/subscriptions/tests/test_views.py index 8f2f3f9a10f..61d280cab7f 100644 --- a/readthedocs/subscriptions/tests/test_views.py +++ b/readthedocs/subscriptions/tests/test_views.py @@ -34,6 +34,7 @@ def setUp(self): self.stripe_customer = self.stripe_subscription.customer self.organization.stripe_customer = self.stripe_customer + self.organization.stripe_subscription = self.stripe_subscription self.organization.save() self.subscription = get( Subscription, @@ -113,10 +114,12 @@ def test_user_without_subscription( self.organization.refresh_from_db() self.organization.stripe_customer = None + self.organization.stripe_subscription = None self.organization.save() self.subscription.delete() self.assertFalse(hasattr(self.organization, 'subscription')) self.assertIsNone(self.organization.stripe_customer) + self.assertIsNone(self.organization.stripe_subscription) resp = self.client.get(reverse('subscription_detail', args=[self.organization.slug])) self.assertEqual(resp.status_code, 200) @@ -125,6 +128,7 @@ def test_user_without_subscription( self.assertEqual(subscription.status, 'active') self.assertEqual(subscription.stripe_id, 'sub_a1b2c3') self.assertEqual(self.organization.stripe_customer, stripe_customer) + self.assertEqual(self.organization.stripe_subscription, stripe_subscription) customer_retrieve_mock.assert_called_once() customer_create_mock.assert_not_called() @@ -146,12 +150,14 @@ def test_user_without_subscription_and_customer( # When stripe_id is None, a new customer is created. self.organization.stripe_id = None self.organization.stripe_customer = None + self.organization.stripe_subscription = None self.organization.save() self.subscription.delete() self.organization.refresh_from_db() self.assertFalse(hasattr(self.organization, 'subscription')) self.assertIsNone(self.organization.stripe_id) self.assertIsNone(self.organization.stripe_customer) + self.assertIsNone(self.organization.stripe_subscription) customer_retrieve_mock.reset_mock() resp = self.client.get(reverse('subscription_detail', args=[self.organization.slug])) @@ -162,6 +168,7 @@ def test_user_without_subscription_and_customer( self.assertEqual(subscription.stripe_id, 'sub_a1b2c3') self.assertEqual(self.organization.stripe_id, 'cus_a1b2c3') self.assertEqual(self.organization.stripe_customer, stripe_customer) + self.assertEqual(self.organization.stripe_subscription, stripe_subscription) customer_create_mock.assert_called_once() customer_retrieve_mock.assert_not_called() diff --git a/readthedocs/subscriptions/utils.py b/readthedocs/subscriptions/utils.py index e5334d50c68..7fcbeb38e1d 100644 --- a/readthedocs/subscriptions/utils.py +++ b/readthedocs/subscriptions/utils.py @@ -86,4 +86,7 @@ def get_or_create_stripe_subscription(organization): trial_period_days=settings.RTD_ORG_TRIAL_PERIOD_DAYS, ) stripe_subscription = stripe_customer.subscriptions.latest() + if organization.stripe_subscription != stripe_subscription: + organization.stripe_subscription = stripe_subscription + organization.save() return stripe_subscription diff --git a/readthedocs/subscriptions/views.py b/readthedocs/subscriptions/views.py index 0774c303f6e..a49d74aae2e 100644 --- a/readthedocs/subscriptions/views.py +++ b/readthedocs/subscriptions/views.py @@ -109,7 +109,7 @@ def get_object(self): We retry the operation when the user visits the subscription page. """ org = self.get_organization() - return org.stripe_subscription + return org.get_or_create_stripe_subscription() def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs)