diff --git a/corporate/lib/decorator.py b/corporate/lib/decorator.py index 8864883126..39b44deb90 100644 --- a/corporate/lib/decorator.py +++ b/corporate/lib/decorator.py @@ -1,17 +1,21 @@ from functools import wraps from typing import Callable +from urllib.parse import urlencode, urljoin from django.conf import settings -from django.http import HttpRequest, HttpResponse +from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.shortcuts import render from typing_extensions import Concatenate, ParamSpec from corporate.lib.remote_billing_util import ( + RemoteBillingIdentityExpiredError, get_remote_realm_from_session, get_remote_server_from_session, ) from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession from zerver.lib.subdomains import get_subdomain +from zerver.lib.url_encoding import append_url_query_string +from zilencer.models import RemoteRealm ParamT = ParamSpec("ParamT") @@ -52,7 +56,65 @@ def authenticated_remote_realm_management_endpoint( if realm_uuid is not None and not isinstance(realm_uuid, str): raise TypeError("realm_uuid must be a string or None") - remote_realm = get_remote_realm_from_session(request, realm_uuid) + try: + remote_realm = get_remote_realm_from_session(request, realm_uuid) + except RemoteBillingIdentityExpiredError as e: + # The user had an authenticated session with an identity_dict, + # but it expired. + # We want to redirect back to the start of their login flow + # at their {realm.host}/self-hosted-billing/ with a proper + # next parameter to take them back to what they're trying + # to access after re-authing. + # Note: Theoretically we could take the realm_uuid from the request + # path or params to figure out the remote_realm.host for the redirect, + # but that would mean leaking that .host value to anyone who knows + # the uuid. Therefore we limit ourselves to taking the realm_uuid + # from the identity_dict - since that proves that the user at least + # previously was successfully authenticated as a billing admin of that + # realm. + realm_uuid = e.realm_uuid + server_uuid = e.server_uuid + uri_scheme = e.uri_scheme + if realm_uuid is None: + # This doesn't make sense - if get_remote_realm_from_session + # found an expired identity dict, it should have had a realm_uuid. + raise AssertionError + + assert server_uuid is not None, "identity_dict with realm_uuid must have server_uuid" + assert uri_scheme is not None, "identity_dict with realm_uuid must have uri_scheme" + + try: + remote_realm = RemoteRealm.objects.get(uuid=realm_uuid, server__uuid=server_uuid) + except RemoteRealm.DoesNotExist: + # This should be impossible - unless the RemoteRealm existed and somehow the row + # was deleted. + raise AssertionError + + # Using EXTERNAL_URI_SCHEME means we'll do https:// in production, which is + # the sane default - while having http:// in development, which will allow + # these redirects to work there for testing. + url = urljoin(uri_scheme + remote_realm.host, "/self-hosted-billing/") + + # Our endpoint URLs in this subsystem end with something like + # /sponsorship or /plans etc. + # Therefore we can use this nice property to figure out easily what + # kind of page the user is trying to access and find the right value + # for the `next` query parameter. + path = request.path + if path.endswith("/"): # nocoverage + path = path[:-1] + + page_type = path.split("/")[-1] + + from corporate.views.remote_billing_page import ( + VALID_NEXT_PAGES as REMOTE_BILLING_VALID_NEXT_PAGES, + ) + + if page_type in REMOTE_BILLING_VALID_NEXT_PAGES: + query = urlencode({"next_page": page_type}) + url = append_url_query_string(url, query) + + return HttpResponseRedirect(url) billing_session = RemoteRealmBillingSession(remote_realm) return view_func(request, billing_session) diff --git a/corporate/lib/remote_billing_util.py b/corporate/lib/remote_billing_util.py index bd4d44e854..53b10dcd8f 100644 --- a/corporate/lib/remote_billing_util.py +++ b/corporate/lib/remote_billing_util.py @@ -1,11 +1,11 @@ import logging -from typing import Optional, TypedDict, Union, cast +from typing import Literal, Optional, TypedDict, Union, cast from django.http import HttpRequest from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ -from zerver.lib.exceptions import JsonableError +from zerver.lib.exceptions import JsonableError, RemoteBillingAuthenticationError from zerver.lib.timestamp import datetime_to_timestamp from zilencer.models import RemoteRealm, RemoteZulipServer @@ -29,6 +29,7 @@ class RemoteBillingIdentityDict(TypedDict): remote_realm_uuid: str authenticated_at: int + uri_scheme: Literal["http://", "https://"] next_page: Optional[str] @@ -41,6 +42,19 @@ class LegacyServerIdentityDict(TypedDict): authenticated_at: int +class RemoteBillingIdentityExpiredError(Exception): + def __init__( + self, + *, + realm_uuid: Optional[str] = None, + server_uuid: Optional[str] = None, + uri_scheme: Optional[Literal["http://", "https://"]] = None, + ) -> None: + self.realm_uuid = realm_uuid + self.server_uuid = server_uuid + self.uri_scheme = uri_scheme + + def get_identity_dict_from_session( request: HttpRequest, *, @@ -66,7 +80,14 @@ def get_identity_dict_from_session( datetime_to_timestamp(timezone_now()) - result["authenticated_at"] > REMOTE_BILLING_SESSION_VALIDITY_SECONDS ): - return None + # In this case we raise, because callers want to catch this as an explicitly + # different scenario from the user not being authenticated, to handle it nicely + # by redirecting them to their login page. + raise RemoteBillingIdentityExpiredError( + realm_uuid=result.get("remote_realm_uuid"), + server_uuid=result.get("remote_server_uuid"), + uri_scheme=result.get("uri_scheme"), + ) return result @@ -83,7 +104,7 @@ def get_remote_realm_from_session( ) if identity_dict is None: - raise JsonableError(_("User not authenticated")) + raise RemoteBillingAuthenticationError remote_server_uuid = identity_dict["remote_server_uuid"] remote_realm_uuid = identity_dict["remote_realm_uuid"] diff --git a/corporate/tests/test_remote_billing.py b/corporate/tests/test_remote_billing.py index 7c6e18c2ce..7836a4707a 100644 --- a/corporate/tests/test_remote_billing.py +++ b/corporate/tests/test_remote_billing.py @@ -55,6 +55,7 @@ class RemoteBillingAuthenticationTest(BouncerTestCase): remote_server_uuid=str(self.server.uuid), remote_realm_uuid=str(user.realm.uuid), authenticated_at=datetime_to_timestamp(now), + uri_scheme="http://", next_page=next_page, ) self.assertEqual( @@ -154,13 +155,81 @@ class RemoteBillingAuthenticationTest(BouncerTestCase): self.assert_in_success_response([desdemona.delivery_email], result) # Now go there again, simulating doing this after the session has expired. - # We should be denied access. + # We should be denied access and redirected to re-auth. with time_machine.travel( now + datetime.timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 1), tick=False, ): result = self.client_get(final_url, subdomain="selfhosting") - self.assert_json_error(result, "User not authenticated") + + self.assertEqual(result.status_code, 302) + self.assertEqual( + result["Location"], + f"http://{desdemona.realm.host}/self-hosted-billing/?next_page=plans", + ) + + # Opening this re-auth URL in result["Location"] is same as re-doing the auth + # flow via execute_remote_billing_authentication_flow with next_page="plans". + # So let's test that and assert that we end up successfully re-authed on the /plans + # page. + result = self.execute_remote_billing_authentication_flow(desdemona, next_page="plans") + self.assertEqual(result["Location"], f"/realm/{realm.uuid!s}/plans") + result = self.client_get(result["Location"], subdomain="selfhosting") + self.assert_in_success_response(["Your remote user info:"], result) + self.assert_in_success_response([desdemona.delivery_email], result) + + @responses.activate + def test_remote_billing_unauthed_access(self) -> None: + now = timezone_now() + self.login("desdemona") + desdemona = self.example_user("desdemona") + realm = desdemona.realm + + self.add_mock_response() + send_realms_only_to_push_bouncer() + + # Straight-up access without authing at all: + result = self.client_get(f"/realm/{realm.uuid!s}/plans", subdomain="selfhosting") + self.assert_json_error(result, "User not authenticated", 401) + + result = self.execute_remote_billing_authentication_flow(desdemona) + self.assertEqual(result["Location"], f"/realm/{realm.uuid!s}/plans") + + final_url = result["Location"] + + # Sanity check - access is granted after authing: + result = self.client_get(final_url, subdomain="selfhosting") + self.assertEqual(result.status_code, 200) + + # Now mess with the identity dict in the session in unlikely ways so that it should + # not grant access. + # First delete the RemoteRealm entry for this session. + RemoteRealm.objects.filter(uuid=realm.uuid).delete() + + with self.assertLogs("django.request", "ERROR") as m, self.assertRaises(AssertionError): + self.client_get(final_url, subdomain="selfhosting") + self.assertIn( + "The remote realm is missing despite being in the RemoteBillingIdentityDict", + m.output[0], + ) + + # Try the case where the identity dict is simultaneously expired. + with time_machine.travel( + now + datetime.timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 30), + tick=False, + ): + with self.assertLogs("django.request", "ERROR") as m, self.assertRaises(AssertionError): + self.client_get(final_url, subdomain="selfhosting") + # The django.request log should be a traceback, mentioning the relevant + # exceptions that occurred. + self.assertIn( + "RemoteBillingIdentityExpiredError", + m.output[0], + ) + self.assertIn( + "AssertionError", + m.output[0], + ) @responses.activate def test_remote_billing_authentication_flow_to_sponsorship_page(self) -> None: diff --git a/corporate/views/remote_billing_page.py b/corporate/views/remote_billing_page.py index be39f3e2d1..ed727269c4 100644 --- a/corporate/views/remote_billing_page.py +++ b/corporate/views/remote_billing_page.py @@ -12,7 +12,10 @@ from django.utils.translation import gettext as _ from django.views.decorators.csrf import csrf_exempt from pydantic import Json -from corporate.lib.decorator import self_hosting_management_endpoint +from corporate.lib.decorator import ( + authenticated_remote_realm_management_endpoint, + self_hosting_management_endpoint, +) from corporate.lib.remote_billing_util import ( REMOTE_BILLING_SESSION_VALIDITY_SECONDS, LegacyServerIdentityDict, @@ -20,6 +23,7 @@ from corporate.lib.remote_billing_util import ( RemoteBillingUserDict, get_identity_dict_from_session, ) +from corporate.lib.stripe import RemoteRealmBillingSession from zerver.lib.exceptions import JsonableError, MissingRemoteRealmError from zerver.lib.remote_server import RealmDataForAnalytics, UserDataForRemoteBilling from zerver.lib.response import json_success @@ -42,6 +46,7 @@ def remote_server_billing_entry( *, user: Json[UserDataForRemoteBilling], realm: Json[RealmDataForAnalytics], + uri_scheme: Literal["http://", "https://"] = "https://", next_page: VALID_NEXT_PAGES_TYPE = None, ) -> HttpResponse: if not settings.DEVELOPMENT: @@ -61,6 +66,7 @@ def remote_server_billing_entry( remote_server_uuid=str(remote_server.uuid), remote_realm_uuid=str(remote_realm.uuid), authenticated_at=datetime_to_timestamp(timezone_now()), + uri_scheme=uri_scheme, next_page=next_page, ) @@ -194,9 +200,11 @@ def remote_billing_plans_common( return render_tmp_remote_billing_page(request, realm_uuid=realm_uuid, server_uuid=server_uuid) -@self_hosting_management_endpoint -@typed_endpoint -def remote_realm_plans_page(request: HttpRequest, *, realm_uuid: PathOnly[str]) -> HttpResponse: +@authenticated_remote_realm_management_endpoint +def remote_realm_plans_page( + request: HttpRequest, billing_session: RemoteRealmBillingSession +) -> HttpResponse: + realm_uuid = str(billing_session.remote_realm.uuid) return remote_billing_plans_common(request, realm_uuid=realm_uuid, server_uuid=None) diff --git a/zerver/lib/exceptions.py b/zerver/lib/exceptions.py index c2bb334b9f..9393626a14 100644 --- a/zerver/lib/exceptions.py +++ b/zerver/lib/exceptions.py @@ -49,6 +49,7 @@ class ErrorCode(Enum): MISSING_REMOTE_REALM = auto() TOPIC_WILDCARD_MENTION_NOT_ALLOWED = auto() STREAM_WILDCARD_MENTION_NOT_ALLOWED = auto() + REMOTE_BILLING_UNAUTHENTICATED_USER = auto() class JsonableError(Exception): @@ -445,6 +446,22 @@ class MissingAuthenticationError(JsonableError): # converted into json_unauthorized in Zulip's middleware. +class RemoteBillingAuthenticationError(JsonableError): + # We want this as a distinct class from MissingAuthenticationError, + # as we don't want the json_unauthorized conversion mechanism to apply + # to this. + code = ErrorCode.REMOTE_BILLING_UNAUTHENTICATED_USER + http_status_code = 401 + + def __init__(self) -> None: + pass + + @staticmethod + @override + def msg_format() -> str: + return _("User not authenticated") + + class InvalidSubdomainError(JsonableError): code = ErrorCode.NONEXISTENT_SUBDOMAIN http_status_code = 404 diff --git a/zerver/views/push_notifications.py b/zerver/views/push_notifications.py index a2af156f0f..46687a3bff 100644 --- a/zerver/views/push_notifications.py +++ b/zerver/views/push_notifications.py @@ -148,6 +148,11 @@ def self_hosting_auth_redirect( post_data = { "user": user_info.model_dump_json(), "realm": realm_info.model_dump_json(), + # The uri_scheme is necessary for the bouncer to know the correct URL + # to redirect the user to for re-authing in case the session expires. + # Otherwise, the bouncer would know only the realm.host but be missing + # the knowledge of whether to use http or https. + "uri_scheme": settings.EXTERNAL_URI_SCHEME, } if next_page is not None: post_data["next_page"] = next_page