mas_handlers/oauth2/
token.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::{Arc, LazyLock};
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::typed_header::TypedHeader;
11use chrono::Duration;
12use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
13use hyper::StatusCode;
14use mas_axum_utils::{
15    client_authorization::{ClientAuthorization, CredentialsVerificationError},
16    record_error,
17};
18use mas_data_model::{
19    AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType,
20};
21use mas_i18n::DataLocale;
22use mas_keystore::{Encrypter, Keystore};
23use mas_matrix::HomeserverConnection;
24use mas_oidc_client::types::scope::ScopeToken;
25use mas_policy::Policy;
26use mas_router::UrlBuilder;
27use mas_storage::{
28    BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
29    oauth2::{
30        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
31        OAuth2RefreshTokenRepository, OAuth2SessionRepository,
32    },
33    user::BrowserSessionRepository,
34};
35use mas_templates::{DeviceNameContext, TemplateContext, Templates};
36use oauth2_types::{
37    errors::{ClientError, ClientErrorCode},
38    pkce::CodeChallengeError,
39    requests::{
40        AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, ClientCredentialsGrant,
41        DeviceCodeGrant, GrantType, RefreshTokenGrant,
42    },
43    scope,
44};
45use opentelemetry::{Key, KeyValue, metrics::Counter};
46use thiserror::Error;
47use tracing::{debug, info, warn};
48use ulid::Ulid;
49
50use super::{generate_id_token, generate_token_pair};
51use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
52
53static TOKEN_REQUEST_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
54    METER
55        .u64_counter("mas.oauth2.token_request")
56        .with_description("How many OAuth 2.0 token requests have gone through")
57        .with_unit("{request}")
58        .build()
59});
/// Metric attribute key carrying the grant type of a successful token request.
const GRANT_TYPE: Key = Key::from_static_str("grant_type");
/// Metric attribute key carrying the outcome of a token request; this file
/// sets it to either `"success"` or `"error"`.
const RESULT: Key = Key::from_static_str("successful");
62
/// Every way the token endpoint can fail.
///
/// The mapping from each variant to an HTTP status and an OAuth 2.0
/// [`ClientErrorCode`] lives in the `IntoResponse` implementation for this
/// type; the per-variant notes below summarise that mapping.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// Any other error, boxed. Answered with `500` / `ServerError` and
    /// reported through `record_error!`.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// The request had no usable form, or a PKCE challenge was present
    /// without a verifier (or vice-versa). Answered with `400` /
    /// `InvalidRequest`.
    #[error("bad request")]
    BadRequest,

    /// The PKCE verifier did not match the stored challenge. Answered with
    /// `400` / `InvalidGrant`, with the underlying error in the description.
    #[error("pkce verification failed")]
    PkceVerification(#[from] CodeChallengeError),

    /// No client matched the presented credentials. Answered with `401` /
    /// `InvalidClient`.
    #[error("client not found")]
    ClientNotFound,

    /// The client has no `token_endpoint_auth_method` configured. Answered
    /// with `401` / `UnauthorizedClient`.
    #[error("client not allowed to use the token endpoint: {0}")]
    ClientNotAllowed(Ulid),

    /// Credential verification failed for a non-internal reason (the client
    /// presented bad credentials). Answered with `401` / `InvalidClient`.
    #[error("invalid client credentials for client {client_id}")]
    InvalidClientCredentials {
        client_id: Ulid,
        #[source]
        source: CredentialsVerificationError,
    },

    /// Credential verification failed for an internal reason
    /// (`CredentialsVerificationError::is_internal`). Answered with `500` /
    /// `ServerError` and reported through `record_error!`.
    #[error("could not verify client credentials for client {client_id}")]
    ClientCredentialsVerification {
        client_id: Ulid,
        #[source]
        source: CredentialsVerificationError,
    },

    /// No grant matched the presented authorization code or device code.
    /// Answered with `400` / `InvalidGrant`.
    #[error("grant not found")]
    GrantNotFound,

    /// The grant (identified by its ID) exists but cannot be exchanged.
    /// Answered with `400` / `InvalidGrant`.
    #[error("invalid grant {0}")]
    InvalidGrant(Ulid),

    /// No refresh token matched the presented value. Answered with `400` /
    /// `InvalidGrant`.
    #[error("refresh token not found")]
    RefreshTokenNotFound,

    /// The refresh token (identified by its ID) was already consumed and the
    /// request is treated as a replay. Answered with `400` / `InvalidGrant`.
    #[error("refresh token {0} is invalid")]
    RefreshTokenInvalid(Ulid),

    /// The OAuth 2.0 session behind the refresh token is no longer valid.
    /// Answered with `400` / `InvalidGrant`.
    #[error("session {0} is invalid")]
    SessionInvalid(Ulid),

    /// The authenticated client is not the one the session or grant belongs
    /// to. Answered with `400` / `InvalidGrant`.
    #[error("client id mismatch: expected {expected}, got {actual}")]
    ClientIDMismatch { expected: Ulid, actual: Ulid },

    /// The policy engine rejected the request. Answered with `403` /
    /// `InvalidScope`, with the violation messages joined in the description.
    #[error("policy denied the request: {0}")]
    DeniedByPolicy(mas_policy::EvaluationResult),

    /// The requested grant type is not handled by this endpoint. Answered
    /// with `400` / `UnsupportedGrantType`.
    #[error("unsupported grant type")]
    UnsupportedGrantType,

    /// The client's `grant_types` do not include the requested grant type.
    /// Answered with `401` / `UnauthorizedClient`.
    #[error("client {0} is not authorized to use this grant type")]
    UnauthorizedClient(Ulid),

    /// The authenticated client differs from the client of the session tied
    /// to the authorization code. Answered with `401` /
    /// `UnauthorizedClient`. (The variant name is misspelled; it is kept
    /// as-is because other code refers to it.)
    #[error("unexpected client {was} (expected {expected})")]
    UnexptectedClient { was: Ulid, expected: Ulid },

    /// A browser session referenced by other data could not be loaded.
    /// Answered with `500` / `ServerError`.
    #[error("failed to load browser session {0}")]
    NoSuchBrowserSession(Ulid),

    /// An OAuth 2.0 session referenced by other data could not be loaded.
    /// Answered with `500` / `ServerError`.
    #[error("failed to load oauth session {0}")]
    NoSuchOAuthSession(Ulid),

    /// A consumed refresh token points at a successor that could not be
    /// loaded. Answered with `500` / `ServerError`.
    #[error(
        "failed to load the next refresh token ({next:?}) from the previous one ({previous:?})"
    )]
    NoSuchNextRefreshToken { next: Ulid, previous: Ulid },

    /// The access token attached to the successor refresh token could not be
    /// loaded. Answered with `500` / `ServerError`.
    #[error(
        "failed to load the access token ({access_token:?}) associated with the next refresh token ({refresh_token:?})"
    )]
    NoSuchNextAccessToken {
        access_token: Ulid,
        refresh_token: Ulid,
    },

    /// The successor refresh token has no access token attached. Answered
    /// with `500` / `ServerError`.
    #[error("no access token associated with the refresh token {refresh_token:?}")]
    NoAccessTokenOnRefreshToken { refresh_token: Ulid },

    /// The device code grant is past its `expires_at`. Answered with `403` /
    /// `ExpiredToken`.
    #[error("device code grant expired")]
    DeviceCodeExpired,

    /// The device code grant is still in the `Pending` state. Answered with
    /// `403` / `AuthorizationPending`.
    #[error("device code grant is still pending")]
    DeviceCodePending,

    /// The device code grant is in the `Rejected` state. Answered with
    /// `403` / `AccessDenied`.
    #[error("device code grant was rejected")]
    DeviceCodeRejected,

    /// The device code grant was already exchanged. Answered with `400` /
    /// `InvalidGrant`.
    #[error("device code grant was already exchanged")]
    DeviceCodeExchanged,

    /// Creating a device on the homeserver failed. Answered with `500` /
    /// `ServerError`.
    #[error("failed to provision device")]
    ProvisionDeviceFailed(#[source] anyhow::Error),
}
161
162impl IntoResponse for RouteError {
163    fn into_response(self) -> axum::response::Response {
164        let sentry_event_id = record_error!(
165            self,
166            Self::Internal(_)
167                | Self::ClientCredentialsVerification { .. }
168                | Self::NoSuchBrowserSession(_)
169                | Self::NoSuchOAuthSession(_)
170                | Self::ProvisionDeviceFailed(_)
171                | Self::NoSuchNextRefreshToken { .. }
172                | Self::NoSuchNextAccessToken { .. }
173                | Self::NoAccessTokenOnRefreshToken { .. }
174        );
175
176        TOKEN_REQUEST_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
177
178        let response = match self {
179            Self::Internal(_)
180            | Self::ClientCredentialsVerification { .. }
181            | Self::NoSuchBrowserSession(_)
182            | Self::NoSuchOAuthSession(_)
183            | Self::ProvisionDeviceFailed(_)
184            | Self::NoSuchNextRefreshToken { .. }
185            | Self::NoSuchNextAccessToken { .. }
186            | Self::NoAccessTokenOnRefreshToken { .. } => (
187                StatusCode::INTERNAL_SERVER_ERROR,
188                Json(ClientError::from(ClientErrorCode::ServerError)),
189            ),
190
191            Self::BadRequest => (
192                StatusCode::BAD_REQUEST,
193                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
194            ),
195
196            Self::PkceVerification(err) => (
197                StatusCode::BAD_REQUEST,
198                Json(
199                    ClientError::from(ClientErrorCode::InvalidGrant)
200                        .with_description(format!("PKCE verification failed: {err}")),
201                ),
202            ),
203
204            Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
205                StatusCode::UNAUTHORIZED,
206                Json(ClientError::from(ClientErrorCode::InvalidClient)),
207            ),
208
209            Self::ClientNotAllowed(_)
210            | Self::UnauthorizedClient(_)
211            | Self::UnexptectedClient { .. } => (
212                StatusCode::UNAUTHORIZED,
213                Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
214            ),
215
216            Self::DeniedByPolicy(evaluation) => (
217                StatusCode::FORBIDDEN,
218                Json(
219                    ClientError::from(ClientErrorCode::InvalidScope).with_description(
220                        evaluation
221                            .violations
222                            .into_iter()
223                            .map(|violation| violation.msg)
224                            .collect::<Vec<_>>()
225                            .join(", "),
226                    ),
227                ),
228            ),
229
230            Self::DeviceCodeRejected => (
231                StatusCode::FORBIDDEN,
232                Json(ClientError::from(ClientErrorCode::AccessDenied)),
233            ),
234
235            Self::DeviceCodeExpired => (
236                StatusCode::FORBIDDEN,
237                Json(ClientError::from(ClientErrorCode::ExpiredToken)),
238            ),
239
240            Self::DeviceCodePending => (
241                StatusCode::FORBIDDEN,
242                Json(ClientError::from(ClientErrorCode::AuthorizationPending)),
243            ),
244
245            Self::InvalidGrant(_)
246            | Self::DeviceCodeExchanged
247            | Self::RefreshTokenNotFound
248            | Self::RefreshTokenInvalid(_)
249            | Self::SessionInvalid(_)
250            | Self::ClientIDMismatch { .. }
251            | Self::GrantNotFound => (
252                StatusCode::BAD_REQUEST,
253                Json(ClientError::from(ClientErrorCode::InvalidGrant)),
254            ),
255
256            Self::UnsupportedGrantType => (
257                StatusCode::BAD_REQUEST,
258                Json(ClientError::from(ClientErrorCode::UnsupportedGrantType)),
259            ),
260        };
261
262        (sentry_event_id, response).into_response()
263    }
264}
265
// Error types which the handlers below propagate with `?` into a `RouteError`.
// NOTE(review): `impl_from_error_for_route!` is defined elsewhere in the crate;
// presumably it generates a `From` impl wrapping the error in
// `RouteError::Internal` — confirm against the macro definition.
impl_from_error_for_route!(mas_i18n::DataError);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(super::IdTokenSignatureError);
271
272#[tracing::instrument(
273    name = "handlers.oauth2.token.post",
274    fields(client.id = client_authorization.client_id()),
275    skip_all,
276)]
277pub(crate) async fn post(
278    mut rng: BoxRng,
279    clock: BoxClock,
280    State(http_client): State<reqwest::Client>,
281    State(key_store): State<Keystore>,
282    State(url_builder): State<UrlBuilder>,
283    activity_tracker: BoundActivityTracker,
284    mut repo: BoxRepository,
285    State(homeserver): State<Arc<dyn HomeserverConnection>>,
286    State(site_config): State<SiteConfig>,
287    State(encrypter): State<Encrypter>,
288    State(templates): State<Templates>,
289    policy: Policy,
290    user_agent: Option<TypedHeader<headers::UserAgent>>,
291    client_authorization: ClientAuthorization<AccessTokenRequest>,
292) -> Result<impl IntoResponse, RouteError> {
293    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
294    let client = client_authorization
295        .credentials
296        .fetch(&mut repo)
297        .await?
298        .ok_or(RouteError::ClientNotFound)?;
299
300    let method = client
301        .token_endpoint_auth_method
302        .as_ref()
303        .ok_or(RouteError::ClientNotAllowed(client.id))?;
304
305    client_authorization
306        .credentials
307        .verify(&http_client, &encrypter, method, &client)
308        .await
309        .map_err(|err| {
310            // Classify the error differntly, depending on whether it's an 'internal' error,
311            // or just because the client presented invalid credentials.
312            if err.is_internal() {
313                RouteError::ClientCredentialsVerification {
314                    client_id: client.id,
315                    source: err,
316                }
317            } else {
318                RouteError::InvalidClientCredentials {
319                    client_id: client.id,
320                    source: err,
321                }
322            }
323        })?;
324
325    let form = client_authorization.form.ok_or(RouteError::BadRequest)?;
326
327    let grant_type = form.grant_type();
328
329    let (reply, repo) = match form {
330        AccessTokenRequest::AuthorizationCode(grant) => {
331            authorization_code_grant(
332                &mut rng,
333                &clock,
334                &activity_tracker,
335                &grant,
336                &client,
337                &key_store,
338                &url_builder,
339                &site_config,
340                repo,
341                &homeserver,
342                &templates,
343                user_agent,
344            )
345            .await?
346        }
347        AccessTokenRequest::RefreshToken(grant) => {
348            refresh_token_grant(
349                &mut rng,
350                &clock,
351                &activity_tracker,
352                &grant,
353                &client,
354                &site_config,
355                repo,
356                user_agent,
357            )
358            .await?
359        }
360        AccessTokenRequest::ClientCredentials(grant) => {
361            client_credentials_grant(
362                &mut rng,
363                &clock,
364                &activity_tracker,
365                &grant,
366                &client,
367                &site_config,
368                repo,
369                policy,
370                user_agent,
371            )
372            .await?
373        }
374        AccessTokenRequest::DeviceCode(grant) => {
375            device_code_grant(
376                &mut rng,
377                &clock,
378                &activity_tracker,
379                &grant,
380                &client,
381                &key_store,
382                &url_builder,
383                &site_config,
384                repo,
385                &homeserver,
386                user_agent,
387            )
388            .await?
389        }
390        _ => {
391            return Err(RouteError::UnsupportedGrantType);
392        }
393    };
394
395    repo.save().await?;
396
397    TOKEN_REQUEST_COUNTER.add(
398        1,
399        &[
400            KeyValue::new(GRANT_TYPE, grant_type),
401            KeyValue::new(RESULT, "success"),
402        ],
403    );
404
405    let mut headers = HeaderMap::new();
406    headers.typed_insert(CacheControl::new().with_no_store());
407    headers.typed_insert(Pragma::no_cache());
408
409    Ok((headers, Json(reply)))
410}
411
/// Handles the `authorization_code` grant: exchanges an authorization code
/// for an access token, a refresh token and, when the session scope contains
/// `openid`, an ID token.
///
/// The function also provisions on the homeserver every device found in the
/// session scope, marks the authorization grant as exchanged and records
/// activity for the session. The repository is handed back to the caller,
/// which is responsible for saving it.
///
/// # Errors
///
/// Returns a [`RouteError`] when the client may not use this grant type, the
/// code is unknown, the grant is not in the `Fulfilled` stage or was
/// fulfilled more than 10 minutes ago, the client does not match the session,
/// PKCE verification fails, or any repository, template, signing or
/// homeserver operation fails.
#[allow(clippy::too_many_lines)] // TODO: refactor some parts out
async fn authorization_code_grant(
    mut rng: &mut BoxRng,
    clock: &impl Clock,
    activity_tracker: &BoundActivityTracker,
    grant: &AuthorizationCodeGrant,
    client: &Client,
    key_store: &Keystore,
    url_builder: &UrlBuilder,
    site_config: &SiteConfig,
    mut repo: BoxRepository,
    homeserver: &Arc<dyn HomeserverConnection>,
    templates: &Templates,
    user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
    // Check that the client is allowed to use this grant type
    if !client.grant_types.contains(&GrantType::AuthorizationCode) {
        return Err(RouteError::UnauthorizedClient(client.id));
    }

    let authz_grant = repo
        .oauth2_authorization_grant()
        .find_by_code(&grant.code)
        .await?
        .ok_or(RouteError::GrantNotFound)?;

    let now = clock.now();

    // Only a grant in the `Fulfilled` stage yields a session ID; every other
    // stage ends the request with an `InvalidGrant` error
    let session_id = match authz_grant.stage {
        AuthorizationGrantStage::Cancelled { cancelled_at } => {
            debug!(%cancelled_at, "Authorization grant was cancelled");
            return Err(RouteError::InvalidGrant(authz_grant.id));
        }
        AuthorizationGrantStage::Exchanged {
            exchanged_at,
            fulfilled_at,
            session_id,
        } => {
            warn!(%exchanged_at, %fulfilled_at, "Authorization code was already exchanged");

            // Ending the session if the token was already exchanged more than 20s ago
            if now - exchanged_at > Duration::microseconds(20 * 1000 * 1000) {
                warn!(oauth_session.id = %session_id, "Ending potentially compromised session");
                let session = repo
                    .oauth2_session()
                    .lookup(session_id)
                    .await?
                    .ok_or(RouteError::NoSuchOAuthSession(session_id))?;

                // NOTE(review): a commented-out `if !session.is_finished()` guard
                // surrounded the two calls below, so the session is currently
                // finished unconditionally — confirm this is intended.
                // The repository is saved right here because the function
                // returns an error just after.
                repo.oauth2_session().finish(clock, session).await?;
                repo.save().await?;
            }

            return Err(RouteError::InvalidGrant(authz_grant.id));
        }
        AuthorizationGrantStage::Pending => {
            warn!("Authorization grant has not been fulfilled yet");
            return Err(RouteError::InvalidGrant(authz_grant.id));
        }
        AuthorizationGrantStage::Fulfilled {
            session_id,
            fulfilled_at,
        } => {
            // The code is only accepted for 10 minutes after the grant was fulfilled
            if now - fulfilled_at > Duration::microseconds(10 * 60 * 1000 * 1000) {
                warn!("Code exchange took more than 10 minutes");
                return Err(RouteError::InvalidGrant(authz_grant.id));
            }

            session_id
        }
    };

    let mut session = repo
        .oauth2_session()
        .lookup(session_id)
        .await?
        .ok_or(RouteError::NoSuchOAuthSession(session_id))?;

    // Generate a device name, rendered in the locale stored on the grant
    // (falling back to "en")
    let lang: DataLocale = authz_grant.locale.as_deref().unwrap_or("en").parse()?;
    let ctx = DeviceNameContext::new(client.clone(), user_agent.clone()).with_language(lang);
    let device_name = templates.render_device_name(&ctx)?;

    if let Some(user_agent) = user_agent {
        session = repo
            .oauth2_session()
            .record_user_agent(session, user_agent)
            .await?;
    }

    // This should never happen, since we looked up in the database using the code
    let code = authz_grant
        .code
        .as_ref()
        .ok_or(RouteError::InvalidGrant(authz_grant.id))?;

    // The code must be redeemed by the client the session belongs to
    if client.id != session.client_id {
        return Err(RouteError::UnexptectedClient {
            was: client.id,
            expected: session.client_id,
        });
    }

    match (code.pkce.as_ref(), grant.code_verifier.as_ref()) {
        (None, None) => {}
        // We have a challenge but no verifier (or vice-versa)? Bad request.
        (Some(_), None) | (None, Some(_)) => return Err(RouteError::BadRequest),
        // If we have both, we need to check the code validity
        (Some(pkce), Some(verifier)) => {
            pkce.verify(verifier)?;
        }
    }

    let Some(user_session_id) = session.user_session_id else {
        tracing::warn!("No user session associated with this OAuth2 session");
        return Err(RouteError::InvalidGrant(authz_grant.id));
    };

    let browser_session = repo
        .browser_session()
        .lookup(user_session_id)
        .await?
        .ok_or(RouteError::NoSuchBrowserSession(user_session_id))?;

    // Passed to the ID token generation below
    let last_authentication = repo
        .browser_session()
        .get_last_authentication(&browser_session)
        .await?;

    let ttl = site_config.access_token_ttl;
    let (access_token, refresh_token) =
        generate_token_pair(&mut rng, clock, &mut repo, &session, ttl).await?;

    // An ID token is only issued when the `openid` scope was granted
    let id_token = if session.scope.contains(&scope::OPENID) {
        Some(generate_id_token(
            &mut rng,
            clock,
            url_builder,
            key_store,
            client,
            Some(&authz_grant),
            &browser_session,
            Some(&access_token),
            last_authentication.as_ref(),
        )?)
    } else {
        None
    };

    let mut params = AccessTokenResponse::new(access_token.access_token)
        .with_expires_in(ttl)
        .with_refresh_token(refresh_token.refresh_token)
        .with_scope(session.scope.clone());

    if let Some(id_token) = id_token {
        params = params.with_id_token(id_token);
    }

    // Lock the user sync to make sure we don't get into a race condition
    repo.user()
        .acquire_lock_for_sync(&browser_session.user)
        .await?;

    // Look for device to provision: every scope token which parses as a
    // device gets created on the homeserver with the generated name
    let mxid = homeserver.mxid(&browser_session.user.username);
    for scope in &*session.scope {
        if let Some(device) = Device::from_scope_token(scope) {
            homeserver
                .create_device(&mxid, device.as_str(), Some(&device_name))
                .await
                .map_err(RouteError::ProvisionDeviceFailed)?;
        }
    }

    // Mark the grant as exchanged
    repo.oauth2_authorization_grant()
        .exchange(clock, authz_grant)
        .await?;

    // XXX: there is a potential (but unlikely) race here, where the activity for
    // the session is recorded before the transaction is committed. We would have to
    // save the repository here to fix that.
    activity_tracker
        .record_oauth2_session(clock, &session)
        .await;

    Ok((params, repo))
}
601
/// Handles the `refresh_token` grant: consumes the presented refresh token
/// and issues a fresh access/refresh token pair for the same session.
///
/// A refresh token which was already consumed is tolerated only when the
/// tokens issued from it (the "next" refresh token and its access token) are
/// still unused; in that case those are revoked and a new pair is issued.
/// Any other reuse is rejected with `RouteError::RefreshTokenInvalid`.
///
/// The repository is handed back to the caller, which is responsible for
/// saving it.
///
/// # Errors
///
/// Returns a [`RouteError`] when the client may not use this grant type, the
/// refresh token is unknown or invalid, the session is invalid or belongs to
/// another client, related tokens cannot be loaded, or a repository
/// operation fails.
#[allow(clippy::too_many_lines)]
async fn refresh_token_grant(
    rng: &mut BoxRng,
    clock: &impl Clock,
    activity_tracker: &BoundActivityTracker,
    grant: &RefreshTokenGrant,
    client: &Client,
    site_config: &SiteConfig,
    mut repo: BoxRepository,
    user_agent: Option<String>,
) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
    // Check that the client is allowed to use this grant type
    if !client.grant_types.contains(&GrantType::RefreshToken) {
        return Err(RouteError::UnauthorizedClient(client.id));
    }

    let refresh_token = repo
        .oauth2_refresh_token()
        .find_by_token(&grant.refresh_token)
        .await?
        .ok_or(RouteError::RefreshTokenNotFound)?;

    let mut session = repo
        .oauth2_session()
        .lookup(refresh_token.session_id)
        .await?
        .ok_or(RouteError::NoSuchOAuthSession(refresh_token.session_id))?;

    // Let's for now record the user agent on each refresh, that should be
    // responsive enough and not too much of a burden on the database.
    if let Some(user_agent) = user_agent {
        session = repo
            .oauth2_session()
            .record_user_agent(session, user_agent)
            .await?;
    }

    if !session.is_valid() {
        return Err(RouteError::SessionInvalid(session.id));
    }

    if client.id != session.client_id {
        // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
        return Err(RouteError::ClientIDMismatch {
            expected: session.client_id,
            actual: client.id,
        });
    }

    if !refresh_token.is_valid() {
        // We're seeing a refresh token that already has been consumed, this might be a
        // double-refresh or a replay attack

        // First, get the next refresh token
        let Some(next_refresh_token_id) = refresh_token.next_refresh_token_id() else {
            // If we don't have a 'next' refresh token, it may just be because this was
            // before we were recording those. Let's just treat it as a replay.
            return Err(RouteError::RefreshTokenInvalid(refresh_token.id));
        };

        let Some(next_refresh_token) = repo
            .oauth2_refresh_token()
            .lookup(next_refresh_token_id)
            .await?
        else {
            return Err(RouteError::NoSuchNextRefreshToken {
                next: next_refresh_token_id,
                previous: refresh_token.id,
            });
        };

        // Check if the next refresh token was already consumed or not
        if !next_refresh_token.is_valid() {
            // XXX: This is a replay, we *may* want to invalidate the session
            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
        }

        // Check if the associated access token was already used
        let Some(access_token_id) = next_refresh_token.access_token_id else {
            // This should in theory not happen: this means an access token got cleaned up,
            // but the refresh token was still valid.
            return Err(RouteError::NoAccessTokenOnRefreshToken {
                refresh_token: next_refresh_token.id,
            });
        };

        // Load it
        let next_access_token = repo
            .oauth2_access_token()
            .lookup(access_token_id)
            .await?
            .ok_or(RouteError::NoSuchNextAccessToken {
                access_token: access_token_id,
                refresh_token: next_refresh_token_id,
            })?;

        if next_access_token.is_used() {
            // XXX: This is a replay, we *may* want to invalidate the session
            return Err(RouteError::RefreshTokenInvalid(next_refresh_token.id));
        }

        // Looks like it's a double-refresh, client lost their refresh token on
        // the way back. Let's revoke the unused access and refresh tokens, and
        // issue new ones
        info!(
            oauth2_session.id = %session.id,
            oauth2_client.id = %client.id,
            %refresh_token.id,
            "Refresh token already used, but issued refresh and access tokens are unused. Assuming those were lost; revoking those and reissuing new ones."
        );

        repo.oauth2_access_token()
            .revoke(clock, next_access_token)
            .await?;

        repo.oauth2_refresh_token()
            .revoke(clock, next_refresh_token)
            .await?;
    }

    activity_tracker
        .record_oauth2_session(clock, &session)
        .await;

    let ttl = site_config.access_token_ttl;
    let (new_access_token, new_refresh_token) =
        generate_token_pair(rng, clock, &mut repo, &session, ttl).await?;

    // Consume the presented refresh token, linking it to its replacement
    let refresh_token = repo
        .oauth2_refresh_token()
        .consume(clock, refresh_token, &new_refresh_token)
        .await?;

    // Revoke the access token which was attached to the presented refresh
    // token, if it still exists
    if let Some(access_token_id) = refresh_token.access_token_id {
        let access_token = repo.oauth2_access_token().lookup(access_token_id).await?;
        if let Some(access_token) = access_token {
            // If it is a double-refresh, it might already be revoked
            if !access_token.state.is_revoked() {
                repo.oauth2_access_token()
                    .revoke(clock, access_token)
                    .await?;
            }
        }
    }

    let params = AccessTokenResponse::new(new_access_token.access_token)
        .with_expires_in(ttl)
        .with_refresh_token(new_refresh_token.refresh_token)
        .with_scope(session.scope);

    Ok((params, repo))
}
754
755async fn client_credentials_grant(
756    rng: &mut BoxRng,
757    clock: &impl Clock,
758    activity_tracker: &BoundActivityTracker,
759    grant: &ClientCredentialsGrant,
760    client: &Client,
761    site_config: &SiteConfig,
762    mut repo: BoxRepository,
763    mut policy: Policy,
764    user_agent: Option<String>,
765) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
766    // Check that the client is allowed to use this grant type
767    if !client.grant_types.contains(&GrantType::ClientCredentials) {
768        return Err(RouteError::UnauthorizedClient(client.id));
769    }
770
771    // Default to an empty scope if none is provided
772    let scope = grant
773        .scope
774        .clone()
775        .unwrap_or_else(|| std::iter::empty::<ScopeToken>().collect());
776
777    // Make the request go through the policy engine
778    let res = policy
779        .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
780            user: None,
781            client,
782            scope: &scope,
783            grant_type: mas_policy::GrantType::ClientCredentials,
784            requester: mas_policy::Requester {
785                ip_address: activity_tracker.ip(),
786                user_agent: user_agent.clone(),
787            },
788        })
789        .await?;
790    if !res.valid() {
791        return Err(RouteError::DeniedByPolicy(res));
792    }
793
794    // Start the session
795    let mut session = repo
796        .oauth2_session()
797        .add_from_client_credentials(rng, clock, client, scope)
798        .await?;
799
800    if let Some(user_agent) = user_agent {
801        session = repo
802            .oauth2_session()
803            .record_user_agent(session, user_agent)
804            .await?;
805    }
806
807    let ttl = site_config.access_token_ttl;
808    let access_token_str = TokenType::AccessToken.generate(rng);
809
810    let access_token = repo
811        .oauth2_access_token()
812        .add(rng, clock, &session, access_token_str, Some(ttl))
813        .await?;
814
815    let mut params = AccessTokenResponse::new(access_token.access_token).with_expires_in(ttl);
816
817    // XXX: there is a potential (but unlikely) race here, where the activity for
818    // the session is recorded before the transaction is committed. We would have to
819    // save the repository here to fix that.
820    activity_tracker
821        .record_oauth2_session(clock, &session)
822        .await;
823
824    if !session.scope.is_empty() {
825        // We only return the scope if it's not empty
826        params = params.with_scope(session.scope);
827    }
828
829    Ok((params, repo))
830}
831
832async fn device_code_grant(
833    rng: &mut BoxRng,
834    clock: &impl Clock,
835    activity_tracker: &BoundActivityTracker,
836    grant: &DeviceCodeGrant,
837    client: &Client,
838    key_store: &Keystore,
839    url_builder: &UrlBuilder,
840    site_config: &SiteConfig,
841    mut repo: BoxRepository,
842    homeserver: &Arc<dyn HomeserverConnection>,
843    user_agent: Option<String>,
844) -> Result<(AccessTokenResponse, BoxRepository), RouteError> {
845    // Check that the client is allowed to use this grant type
846    if !client.grant_types.contains(&GrantType::DeviceCode) {
847        return Err(RouteError::UnauthorizedClient(client.id));
848    }
849
850    let grant = repo
851        .oauth2_device_code_grant()
852        .find_by_device_code(&grant.device_code)
853        .await?
854        .ok_or(RouteError::GrantNotFound)?;
855
856    // Check that the client match
857    if client.id != grant.client_id {
858        return Err(RouteError::ClientIDMismatch {
859            expected: grant.client_id,
860            actual: client.id,
861        });
862    }
863
864    if grant.expires_at < clock.now() {
865        return Err(RouteError::DeviceCodeExpired);
866    }
867
868    let browser_session_id = match &grant.state {
869        DeviceCodeGrantState::Pending => {
870            return Err(RouteError::DeviceCodePending);
871        }
872        DeviceCodeGrantState::Rejected { .. } => {
873            return Err(RouteError::DeviceCodeRejected);
874        }
875        DeviceCodeGrantState::Exchanged { .. } => {
876            return Err(RouteError::DeviceCodeExchanged);
877        }
878        DeviceCodeGrantState::Fulfilled {
879            browser_session_id, ..
880        } => *browser_session_id,
881    };
882
883    let browser_session = repo
884        .browser_session()
885        .lookup(browser_session_id)
886        .await?
887        .ok_or(RouteError::NoSuchBrowserSession(browser_session_id))?;
888
889    // Start the session
890    let mut session = repo
891        .oauth2_session()
892        .add_from_browser_session(rng, clock, client, &browser_session, grant.scope.clone())
893        .await?;
894
895    repo.oauth2_device_code_grant()
896        .exchange(clock, grant, &session)
897        .await?;
898
899    // XXX: should we get the user agent from the device code grant instead?
900    if let Some(user_agent) = user_agent {
901        session = repo
902            .oauth2_session()
903            .record_user_agent(session, user_agent)
904            .await?;
905    }
906
907    let ttl = site_config.access_token_ttl;
908    let access_token_str = TokenType::AccessToken.generate(rng);
909
910    let access_token = repo
911        .oauth2_access_token()
912        .add(rng, clock, &session, access_token_str, Some(ttl))
913        .await?;
914
915    let mut params =
916        AccessTokenResponse::new(access_token.access_token.clone()).with_expires_in(ttl);
917
918    // If the client uses the refresh token grant type, we also generate a refresh
919    // token
920    if client.grant_types.contains(&GrantType::RefreshToken) {
921        let refresh_token_str = TokenType::RefreshToken.generate(rng);
922
923        let refresh_token = repo
924            .oauth2_refresh_token()
925            .add(rng, clock, &session, &access_token, refresh_token_str)
926            .await?;
927
928        params = params.with_refresh_token(refresh_token.refresh_token);
929    }
930
931    // If the client asked for an ID token, we generate one
932    if session.scope.contains(&scope::OPENID) {
933        let id_token = generate_id_token(
934            rng,
935            clock,
936            url_builder,
937            key_store,
938            client,
939            None,
940            &browser_session,
941            Some(&access_token),
942            None,
943        )?;
944
945        params = params.with_id_token(id_token);
946    }
947
948    // Lock the user sync to make sure we don't get into a race condition
949    repo.user()
950        .acquire_lock_for_sync(&browser_session.user)
951        .await?;
952
953    // Look for device to provision
954    let mxid = homeserver.mxid(&browser_session.user.username);
955    for scope in &*session.scope {
956        if let Some(device) = Device::from_scope_token(scope) {
957            homeserver
958                .create_device(&mxid, device.as_str(), None)
959                .await
960                .map_err(RouteError::ProvisionDeviceFailed)?;
961        }
962    }
963
964    // XXX: there is a potential (but unlikely) race here, where the activity for
965    // the session is recorded before the transaction is committed. We would have to
966    // save the repository here to fix that.
967    activity_tracker
968        .record_oauth2_session(clock, &session)
969        .await;
970
971    if !session.scope.is_empty() {
972        // We only return the scope if it's not empty
973        params = params.with_scope(session.scope);
974    }
975
976    Ok((params, repo))
977}
978
979#[cfg(test)]
980mod tests {
981    use hyper::Request;
982    use mas_data_model::{AccessToken, AuthorizationCode, RefreshToken};
983    use mas_router::SimpleRoute;
984    use oauth2_types::{
985        registration::ClientRegistrationResponse,
986        requests::{DeviceAuthorizationResponse, ResponseMode},
987        scope::{OPENID, Scope},
988    };
989    use sqlx::PgPool;
990
991    use super::*;
992    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
993
994    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
995    async fn test_auth_code_grant(pool: PgPool) {
996        setup();
997        let state = TestState::from_pool(pool).await.unwrap();
998
999        // Provision a client
1000        let request =
1001            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1002                "client_uri": "https://example.com/",
1003                "redirect_uris": ["https://example.com/callback"],
1004                "token_endpoint_auth_method": "none",
1005                "response_types": ["code"],
1006                "grant_types": ["authorization_code"],
1007            }));
1008
1009        let response = state.request(request).await;
1010        response.assert_status(StatusCode::CREATED);
1011
1012        let ClientRegistrationResponse { client_id, .. } = response.json();
1013
1014        // Let's provision a user and create a session for them. This part is hard to
1015        // test with just HTTP requests, so we'll use the repository directly.
1016        let mut repo = state.repository().await.unwrap();
1017
1018        let user = repo
1019            .user()
1020            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1021            .await
1022            .unwrap();
1023
1024        let browser_session = repo
1025            .browser_session()
1026            .add(&mut state.rng(), &state.clock, &user, None)
1027            .await
1028            .unwrap();
1029
1030        // Lookup the client in the database.
1031        let client = repo
1032            .oauth2_client()
1033            .find_by_client_id(&client_id)
1034            .await
1035            .unwrap()
1036            .unwrap();
1037
1038        // Start a grant
1039        let code = "thisisaverysecurecode";
1040        let grant = repo
1041            .oauth2_authorization_grant()
1042            .add(
1043                &mut state.rng(),
1044                &state.clock,
1045                &client,
1046                "https://example.com/redirect".parse().unwrap(),
1047                Scope::from_iter([OPENID]),
1048                Some(AuthorizationCode {
1049                    code: code.to_owned(),
1050                    pkce: None,
1051                }),
1052                Some("state".to_owned()),
1053                Some("nonce".to_owned()),
1054                ResponseMode::Query,
1055                false,
1056                None,
1057                None,
1058            )
1059            .await
1060            .unwrap();
1061
1062        let session = repo
1063            .oauth2_session()
1064            .add_from_browser_session(
1065                &mut state.rng(),
1066                &state.clock,
1067                &client,
1068                &browser_session,
1069                grant.scope.clone(),
1070            )
1071            .await
1072            .unwrap();
1073
1074        // And fulfill it
1075        let grant = repo
1076            .oauth2_authorization_grant()
1077            .fulfill(&state.clock, &session, grant)
1078            .await
1079            .unwrap();
1080
1081        repo.save().await.unwrap();
1082
1083        // Now call the token endpoint to get an access token.
1084        let request =
1085            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1086                "grant_type": "authorization_code",
1087                "code": code,
1088                "redirect_uri": grant.redirect_uri,
1089                "client_id": client.client_id,
1090            }));
1091
1092        let response = state.request(request).await;
1093        response.assert_status(StatusCode::OK);
1094
1095        let AccessTokenResponse { access_token, .. } = response.json();
1096
1097        // Check that the token is valid
1098        assert!(state.is_access_token_valid(&access_token).await);
1099
1100        // Exchange it again, this it should fail
1101        let request =
1102            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1103                "grant_type": "authorization_code",
1104                "code": code,
1105                "redirect_uri": grant.redirect_uri,
1106                "client_id": client.client_id,
1107            }));
1108
1109        let response = state.request(request).await;
1110        response.assert_status(StatusCode::BAD_REQUEST);
1111        let error: ClientError = response.json();
1112        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1113
1114        // The token should still be valid
1115        assert!(state.is_access_token_valid(&access_token).await);
1116
1117        // Now wait a bit
1118        state.clock.advance(Duration::try_minutes(1).unwrap());
1119
1120        // Exchange it again, this it should fail
1121        let request =
1122            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1123                "grant_type": "authorization_code",
1124                "code": code,
1125                "redirect_uri": grant.redirect_uri,
1126                "client_id": client.client_id,
1127            }));
1128
1129        let response = state.request(request).await;
1130        response.assert_status(StatusCode::BAD_REQUEST);
1131        let error: ClientError = response.json();
1132        assert_eq!(error.error, ClientErrorCode::InvalidGrant);
1133
1134        // And it should have revoked the token we got
1135        assert!(!state.is_access_token_valid(&access_token).await);
1136
1137        // Try another one and wait for too long before exchanging it
1138        let mut repo = state.repository().await.unwrap();
1139        let code = "thisisanothercode";
1140        let grant = repo
1141            .oauth2_authorization_grant()
1142            .add(
1143                &mut state.rng(),
1144                &state.clock,
1145                &client,
1146                "https://example.com/redirect".parse().unwrap(),
1147                Scope::from_iter([OPENID]),
1148                Some(AuthorizationCode {
1149                    code: code.to_owned(),
1150                    pkce: None,
1151                }),
1152                Some("state".to_owned()),
1153                Some("nonce".to_owned()),
1154                ResponseMode::Query,
1155                false,
1156                None,
1157                None,
1158            )
1159            .await
1160            .unwrap();
1161
1162        let session = repo
1163            .oauth2_session()
1164            .add_from_browser_session(
1165                &mut state.rng(),
1166                &state.clock,
1167                &client,
1168                &browser_session,
1169                grant.scope.clone(),
1170            )
1171            .await
1172            .unwrap();
1173
1174        // And fulfill it
1175        let grant = repo
1176            .oauth2_authorization_grant()
1177            .fulfill(&state.clock, &session, grant)
1178            .await
1179            .unwrap();
1180
1181        repo.save().await.unwrap();
1182
1183        // Now wait a bit
1184        state
1185            .clock
1186            .advance(Duration::microseconds(15 * 60 * 1000 * 1000));
1187
1188        // Exchange it, it should fail
1189        let request =
1190            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1191                "grant_type": "authorization_code",
1192                "code": code,
1193                "redirect_uri": grant.redirect_uri,
1194                "client_id": client.client_id,
1195            }));
1196
1197        let response = state.request(request).await;
1198        response.assert_status(StatusCode::BAD_REQUEST);
1199        let ClientError { error, .. } = response.json();
1200        assert_eq!(error, ClientErrorCode::InvalidGrant);
1201    }
1202
1203    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1204    async fn test_refresh_token_grant(pool: PgPool) {
1205        setup();
1206        let state = TestState::from_pool(pool).await.unwrap();
1207
1208        // Provision a client
1209        let request =
1210            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1211                "client_uri": "https://example.com/",
1212                "redirect_uris": ["https://example.com/callback"],
1213                "token_endpoint_auth_method": "none",
1214                "response_types": ["code"],
1215                "grant_types": ["authorization_code", "refresh_token"],
1216            }));
1217
1218        let response = state.request(request).await;
1219        response.assert_status(StatusCode::CREATED);
1220
1221        let ClientRegistrationResponse { client_id, .. } = response.json();
1222
1223        // Let's provision a user and create a session for them. This part is hard to
1224        // test with just HTTP requests, so we'll use the repository directly.
1225        let mut repo = state.repository().await.unwrap();
1226
1227        let user = repo
1228            .user()
1229            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1230            .await
1231            .unwrap();
1232
1233        let browser_session = repo
1234            .browser_session()
1235            .add(&mut state.rng(), &state.clock, &user, None)
1236            .await
1237            .unwrap();
1238
1239        // Lookup the client in the database.
1240        let client = repo
1241            .oauth2_client()
1242            .find_by_client_id(&client_id)
1243            .await
1244            .unwrap()
1245            .unwrap();
1246
1247        // Get a token pair
1248        let session = repo
1249            .oauth2_session()
1250            .add_from_browser_session(
1251                &mut state.rng(),
1252                &state.clock,
1253                &client,
1254                &browser_session,
1255                Scope::from_iter([OPENID]),
1256            )
1257            .await
1258            .unwrap();
1259
1260        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1261            generate_token_pair(
1262                &mut state.rng(),
1263                &state.clock,
1264                &mut repo,
1265                &session,
1266                Duration::microseconds(5 * 60 * 1000 * 1000),
1267            )
1268            .await
1269            .unwrap();
1270
1271        repo.save().await.unwrap();
1272
1273        // First check that the token is valid
1274        assert!(state.is_access_token_valid(&access_token).await);
1275
1276        // Now call the token endpoint to get an access token.
1277        let request =
1278            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1279                "grant_type": "refresh_token",
1280                "refresh_token": refresh_token,
1281                "client_id": client.client_id,
1282            }));
1283
1284        let response = state.request(request).await;
1285        response.assert_status(StatusCode::OK);
1286
1287        let old_access_token = access_token;
1288        let old_refresh_token = refresh_token;
1289        let response: AccessTokenResponse = response.json();
1290        let access_token = response.access_token;
1291        let refresh_token = response.refresh_token.expect("to have a refresh token");
1292
1293        // Check that the new token is valid
1294        assert!(state.is_access_token_valid(&access_token).await);
1295
1296        // Check that the old token is no longer valid
1297        assert!(!state.is_access_token_valid(&old_access_token).await);
1298
1299        // Call it again with the old token, it should fail
1300        let request =
1301            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1302                "grant_type": "refresh_token",
1303                "refresh_token": old_refresh_token,
1304                "client_id": client.client_id,
1305            }));
1306
1307        let response = state.request(request).await;
1308        response.assert_status(StatusCode::BAD_REQUEST);
1309        let ClientError { error, .. } = response.json();
1310        assert_eq!(error, ClientErrorCode::InvalidGrant);
1311
1312        // Call it again with the new token, it should work
1313        let request =
1314            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1315                "grant_type": "refresh_token",
1316                "refresh_token": refresh_token,
1317                "client_id": client.client_id,
1318            }));
1319
1320        let response = state.request(request).await;
1321        response.assert_status(StatusCode::OK);
1322        let _: AccessTokenResponse = response.json();
1323    }
1324
1325    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1326    async fn test_double_refresh(pool: PgPool) {
1327        setup();
1328        let state = TestState::from_pool(pool).await.unwrap();
1329
1330        // Provision a client
1331        let request =
1332            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1333                "client_uri": "https://example.com/",
1334                "redirect_uris": ["https://example.com/callback"],
1335                "token_endpoint_auth_method": "none",
1336                "response_types": ["code"],
1337                "grant_types": ["authorization_code", "refresh_token"],
1338            }));
1339
1340        let response = state.request(request).await;
1341        response.assert_status(StatusCode::CREATED);
1342
1343        let ClientRegistrationResponse { client_id, .. } = response.json();
1344
1345        // Let's provision a user and create a session for them. This part is hard to
1346        // test with just HTTP requests, so we'll use the repository directly.
1347        let mut repo = state.repository().await.unwrap();
1348
1349        let user = repo
1350            .user()
1351            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1352            .await
1353            .unwrap();
1354
1355        let browser_session = repo
1356            .browser_session()
1357            .add(&mut state.rng(), &state.clock, &user, None)
1358            .await
1359            .unwrap();
1360
1361        // Lookup the client in the database.
1362        let client = repo
1363            .oauth2_client()
1364            .find_by_client_id(&client_id)
1365            .await
1366            .unwrap()
1367            .unwrap();
1368
1369        // Get a token pair
1370        let session = repo
1371            .oauth2_session()
1372            .add_from_browser_session(
1373                &mut state.rng(),
1374                &state.clock,
1375                &client,
1376                &browser_session,
1377                Scope::from_iter([OPENID]),
1378            )
1379            .await
1380            .unwrap();
1381
1382        let (AccessToken { access_token, .. }, RefreshToken { refresh_token, .. }) =
1383            generate_token_pair(
1384                &mut state.rng(),
1385                &state.clock,
1386                &mut repo,
1387                &session,
1388                Duration::microseconds(5 * 60 * 1000 * 1000),
1389            )
1390            .await
1391            .unwrap();
1392
1393        repo.save().await.unwrap();
1394
1395        // First check that the token is valid
1396        assert!(state.is_access_token_valid(&access_token).await);
1397
1398        // Now call the token endpoint to get an access token.
1399        let request =
1400            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1401                "grant_type": "refresh_token",
1402                "refresh_token": refresh_token,
1403                "client_id": client.client_id,
1404            }));
1405
1406        let first_response = state.request(request).await;
1407        first_response.assert_status(StatusCode::OK);
1408        let first_response: AccessTokenResponse = first_response.json();
1409
1410        // Call a second time, it should work, as we haven't done anything yet with the
1411        // token
1412        let request =
1413            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1414                "grant_type": "refresh_token",
1415                "refresh_token": refresh_token,
1416                "client_id": client.client_id,
1417            }));
1418
1419        let second_response = state.request(request).await;
1420        second_response.assert_status(StatusCode::OK);
1421        let second_response: AccessTokenResponse = second_response.json();
1422
1423        // Check that we got new tokens
1424        assert_ne!(first_response.access_token, second_response.access_token);
1425        assert_ne!(first_response.refresh_token, second_response.refresh_token);
1426
1427        // Check that the old-new token is invalid
1428        assert!(
1429            !state
1430                .is_access_token_valid(&first_response.access_token)
1431                .await
1432        );
1433
1434        // Check that the new-new token is valid
1435        assert!(
1436            state
1437                .is_access_token_valid(&second_response.access_token)
1438                .await
1439        );
1440
1441        // Do a third refresh, this one should not work, as we've used the new
1442        // access token
1443        let request =
1444            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1445                "grant_type": "refresh_token",
1446                "refresh_token": refresh_token,
1447                "client_id": client.client_id,
1448            }));
1449
1450        let third_response = state.request(request).await;
1451        third_response.assert_status(StatusCode::BAD_REQUEST);
1452
1453        // The other reason we consider a new refresh token to be 'used' is if
1454        // it was already used in a refresh
1455        // So, if we do a refresh with the second_response.refresh_token, then
1456        // another refresh with the result, redoing one with
1457        // second_response.refresh_token again should fail
1458        let request =
1459            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1460                "grant_type": "refresh_token",
1461                "refresh_token": second_response.refresh_token,
1462                "client_id": client.client_id,
1463            }));
1464
1465        // This one is fine
1466        let fourth_response = state.request(request).await;
1467        fourth_response.assert_status(StatusCode::OK);
1468        let fourth_response: AccessTokenResponse = fourth_response.json();
1469
1470        // Do another one, it should be fine as well
1471        let request =
1472            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1473                "grant_type": "refresh_token",
1474                "refresh_token": fourth_response.refresh_token,
1475                "client_id": client.client_id,
1476            }));
1477
1478        let fifth_response = state.request(request).await;
1479        fifth_response.assert_status(StatusCode::OK);
1480
1481        // But now, if we re-do with the second_response.refresh_token, it should
1482        // fail
1483        let request =
1484            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1485                "grant_type": "refresh_token",
1486                "refresh_token": second_response.refresh_token,
1487                "client_id": client.client_id,
1488            }));
1489
1490        let sixth_response = state.request(request).await;
1491        sixth_response.assert_status(StatusCode::BAD_REQUEST);
1492    }
1493
1494    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1495    async fn test_client_credentials(pool: PgPool) {
1496        setup();
1497        let state = TestState::from_pool(pool).await.unwrap();
1498
1499        // Provision a client
1500        let request =
1501            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1502                "client_uri": "https://example.com/",
1503                "token_endpoint_auth_method": "client_secret_post",
1504                "grant_types": ["client_credentials"],
1505            }));
1506
1507        let response = state.request(request).await;
1508        response.assert_status(StatusCode::CREATED);
1509
1510        let response: ClientRegistrationResponse = response.json();
1511        let client_id = response.client_id;
1512        let client_secret = response.client_secret.expect("to have a client secret");
1513
1514        // Call the token endpoint with an empty scope
1515        let request =
1516            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1517                "grant_type": "client_credentials",
1518                "client_id": client_id,
1519                "client_secret": client_secret,
1520            }));
1521
1522        let response = state.request(request).await;
1523        response.assert_status(StatusCode::OK);
1524
1525        let response: AccessTokenResponse = response.json();
1526        assert!(response.refresh_token.is_none());
1527        assert!(response.expires_in.is_some());
1528        assert!(response.scope.is_none());
1529
1530        // Revoke the token
1531        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1532            "token": response.access_token,
1533            "client_id": client_id,
1534            "client_secret": client_secret,
1535        }));
1536
1537        let response = state.request(request).await;
1538        response.assert_status(StatusCode::OK);
1539
1540        // We should be allowed to ask for the GraphQL API scope
1541        let request =
1542            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1543                "grant_type": "client_credentials",
1544                "client_id": client_id,
1545                "client_secret": client_secret,
1546                "scope": "urn:mas:graphql:*"
1547            }));
1548
1549        let response = state.request(request).await;
1550        response.assert_status(StatusCode::OK);
1551
1552        let response: AccessTokenResponse = response.json();
1553        assert!(response.refresh_token.is_none());
1554        assert!(response.expires_in.is_some());
1555        assert_eq!(response.scope, Some("urn:mas:graphql:*".parse().unwrap()));
1556
1557        // Revoke the token
1558        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1559            "token": response.access_token,
1560            "client_id": client_id,
1561            "client_secret": client_secret,
1562        }));
1563
1564        let response = state.request(request).await;
1565        response.assert_status(StatusCode::OK);
1566
1567        // We should be NOT allowed to ask for the MAS admin scope
1568        let request =
1569            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1570                "grant_type": "client_credentials",
1571                "client_id": client_id,
1572                "client_secret": client_secret,
1573                "scope": "urn:mas:admin"
1574            }));
1575
1576        let response = state.request(request).await;
1577        response.assert_status(StatusCode::FORBIDDEN);
1578
1579        let ClientError { error, .. } = response.json();
1580        assert_eq!(error, ClientErrorCode::InvalidScope);
1581
1582        // Now, if we add the client to the admin list in the policy, it should work
1583        let state = {
1584            let mut state = state;
1585            state.policy_factory = crate::test_utils::policy_factory(
1586                "example.com",
1587                serde_json::json!({
1588                    "admin_clients": [client_id]
1589                }),
1590            )
1591            .await
1592            .unwrap();
1593            state
1594        };
1595
1596        let request =
1597            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1598                "grant_type": "client_credentials",
1599                "client_id": client_id,
1600                "client_secret": client_secret,
1601                "scope": "urn:mas:admin"
1602            }));
1603
1604        let response = state.request(request).await;
1605        response.assert_status(StatusCode::OK);
1606
1607        let response: AccessTokenResponse = response.json();
1608        assert!(response.refresh_token.is_none());
1609        assert!(response.expires_in.is_some());
1610        assert_eq!(response.scope, Some("urn:mas:admin".parse().unwrap()));
1611
1612        // Revoke the token
1613        let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
1614            "token": response.access_token,
1615            "client_id": client_id,
1616            "client_secret": client_secret,
1617        }));
1618
1619        let response = state.request(request).await;
1620        response.assert_status(StatusCode::OK);
1621    }
1622
1623    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1624    async fn test_device_code_grant(pool: PgPool) {
1625        setup();
1626        let state = TestState::from_pool(pool).await.unwrap();
1627
1628        // Provision a client
1629        let request =
1630            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1631                "client_uri": "https://example.com/",
1632                "token_endpoint_auth_method": "none",
1633                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"],
1634                "response_types": [],
1635            }));
1636
1637        let response = state.request(request).await;
1638        response.assert_status(StatusCode::CREATED);
1639
1640        let response: ClientRegistrationResponse = response.json();
1641        let client_id = response.client_id;
1642
1643        // Start a device code grant
1644        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1645            serde_json::json!({
1646                "client_id": client_id,
1647                "scope": "openid",
1648            }),
1649        );
1650        let response = state.request(request).await;
1651        response.assert_status(StatusCode::OK);
1652
1653        let device_grant: DeviceAuthorizationResponse = response.json();
1654
1655        // Poll the token endpoint, it should be pending
1656        let request =
1657            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1658                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1659                "device_code": device_grant.device_code,
1660                "client_id": client_id,
1661            }));
1662        let response = state.request(request).await;
1663        response.assert_status(StatusCode::FORBIDDEN);
1664
1665        let ClientError { error, .. } = response.json();
1666        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1667
1668        // Let's provision a user and create a browser session for them. This part is
1669        // hard to test with just HTTP requests, so we'll use the repository
1670        // directly.
1671        let mut repo = state.repository().await.unwrap();
1672
1673        let user = repo
1674            .user()
1675            .add(&mut state.rng(), &state.clock, "alice".to_owned())
1676            .await
1677            .unwrap();
1678
1679        let browser_session = repo
1680            .browser_session()
1681            .add(&mut state.rng(), &state.clock, &user, None)
1682            .await
1683            .unwrap();
1684
1685        // Find the grant
1686        let grant = repo
1687            .oauth2_device_code_grant()
1688            .find_by_user_code(&device_grant.user_code)
1689            .await
1690            .unwrap()
1691            .unwrap();
1692
1693        // And fulfill it
1694        let grant = repo
1695            .oauth2_device_code_grant()
1696            .fulfill(&state.clock, grant, &browser_session)
1697            .await
1698            .unwrap();
1699
1700        repo.save().await.unwrap();
1701
1702        // Now call the token endpoint to get an access token.
1703        let request =
1704            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1705                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1706                "device_code": grant.device_code,
1707                "client_id": client_id,
1708            }));
1709
1710        let response = state.request(request).await;
1711        response.assert_status(StatusCode::OK);
1712
1713        let response: AccessTokenResponse = response.json();
1714
1715        // Check that the token is valid
1716        assert!(state.is_access_token_valid(&response.access_token).await);
1717        // We advertised the refresh token grant type, so we should have a refresh token
1718        assert!(response.refresh_token.is_some());
1719        // We asked for the openid scope, so we should have an ID token
1720        assert!(response.id_token.is_some());
1721
1722        // Calling it again should fail
1723        let request =
1724            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1725                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1726                "device_code": grant.device_code,
1727                "client_id": client_id,
1728            }));
1729        let response = state.request(request).await;
1730        response.assert_status(StatusCode::BAD_REQUEST);
1731
1732        let ClientError { error, .. } = response.json();
1733        assert_eq!(error, ClientErrorCode::InvalidGrant);
1734
1735        // Do another grant and make it expire
1736        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1737            serde_json::json!({
1738                "client_id": client_id,
1739                "scope": "openid",
1740            }),
1741        );
1742        let response = state.request(request).await;
1743        response.assert_status(StatusCode::OK);
1744
1745        let device_grant: DeviceAuthorizationResponse = response.json();
1746
1747        // Poll the token endpoint, it should be pending
1748        let request =
1749            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1750                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1751                "device_code": device_grant.device_code,
1752                "client_id": client_id,
1753            }));
1754        let response = state.request(request).await;
1755        response.assert_status(StatusCode::FORBIDDEN);
1756
1757        let ClientError { error, .. } = response.json();
1758        assert_eq!(error, ClientErrorCode::AuthorizationPending);
1759
1760        state.clock.advance(Duration::try_hours(1).unwrap());
1761
1762        // Poll again, it should be expired
1763        let request =
1764            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1765                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1766                "device_code": device_grant.device_code,
1767                "client_id": client_id,
1768            }));
1769        let response = state.request(request).await;
1770        response.assert_status(StatusCode::FORBIDDEN);
1771
1772        let ClientError { error, .. } = response.json();
1773        assert_eq!(error, ClientErrorCode::ExpiredToken);
1774
1775        // Do another grant and reject it
1776        let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(
1777            serde_json::json!({
1778                "client_id": client_id,
1779                "scope": "openid",
1780            }),
1781        );
1782        let response = state.request(request).await;
1783        response.assert_status(StatusCode::OK);
1784
1785        let device_grant: DeviceAuthorizationResponse = response.json();
1786
1787        // Find the grant and reject it
1788        let mut repo = state.repository().await.unwrap();
1789
1790        // Find the grant
1791        let grant = repo
1792            .oauth2_device_code_grant()
1793            .find_by_user_code(&device_grant.user_code)
1794            .await
1795            .unwrap()
1796            .unwrap();
1797
1798        // And reject it
1799        let grant = repo
1800            .oauth2_device_code_grant()
1801            .reject(&state.clock, grant, &browser_session)
1802            .await
1803            .unwrap();
1804
1805        repo.save().await.unwrap();
1806
1807        // Poll the token endpoint, it should be rejected
1808        let request =
1809            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1810                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1811                "device_code": grant.device_code,
1812                "client_id": client_id,
1813            }));
1814        let response = state.request(request).await;
1815        response.assert_status(StatusCode::FORBIDDEN);
1816
1817        let ClientError { error, .. } = response.json();
1818        assert_eq!(error, ClientErrorCode::AccessDenied);
1819    }
1820
1821    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1822    async fn test_unsupported_grant(pool: PgPool) {
1823        setup();
1824        let state = TestState::from_pool(pool).await.unwrap();
1825
1826        // Provision a client
1827        let request =
1828            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
1829                "client_uri": "https://example.com/",
1830                "redirect_uris": ["https://example.com/callback"],
1831                "token_endpoint_auth_method": "client_secret_post",
1832                "grant_types": ["password"],
1833                "response_types": [],
1834            }));
1835
1836        let response = state.request(request).await;
1837        response.assert_status(StatusCode::CREATED);
1838
1839        let response: ClientRegistrationResponse = response.json();
1840        let client_id = response.client_id;
1841        let client_secret = response.client_secret.expect("to have a client secret");
1842
1843        // Call the token endpoint with an unsupported grant type
1844        let request =
1845            Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
1846                "grant_type": "password",
1847                "client_id": client_id,
1848                "client_secret": client_secret,
1849                "username": "john",
1850                "password": "hunter2",
1851            }));
1852
1853        let response = state.request(request).await;
1854        response.assert_status(StatusCode::BAD_REQUEST);
1855        let ClientError { error, .. } = response.json();
1856        assert_eq!(error, ClientErrorCode::UnsupportedGrantType);
1857    }
1858}