mas_storage_pg/oauth2/authorization_grant.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
14use oauth2_types::{requests::ResponseMode, scope::Scope};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
22
23/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
24/// connection
25pub struct PgOAuth2AuthorizationGrantRepository<'c> {
26    conn: &'c mut PgConnection,
27}
28
29impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
30    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
31    /// PostgreSQL connection
32    pub fn new(conn: &'c mut PgConnection) -> Self {
33        Self { conn }
34    }
35}
36
37#[allow(clippy::struct_excessive_bools)]
38struct GrantLookup {
39    oauth2_authorization_grant_id: Uuid,
40    created_at: DateTime<Utc>,
41    cancelled_at: Option<DateTime<Utc>>,
42    fulfilled_at: Option<DateTime<Utc>>,
43    exchanged_at: Option<DateTime<Utc>>,
44    scope: String,
45    state: Option<String>,
46    nonce: Option<String>,
47    redirect_uri: String,
48    response_mode: String,
49    response_type_code: bool,
50    response_type_id_token: bool,
51    authorization_code: Option<String>,
52    code_challenge: Option<String>,
53    code_challenge_method: Option<String>,
54    login_hint: Option<String>,
55    locale: Option<String>,
56    oauth2_client_id: Uuid,
57    oauth2_session_id: Option<Uuid>,
58}
59
60impl TryFrom<GrantLookup> for AuthorizationGrant {
61    type Error = DatabaseInconsistencyError;
62
63    #[allow(clippy::too_many_lines)]
64    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
65        let id = value.oauth2_authorization_grant_id.into();
66        let scope: Scope = value.scope.parse().map_err(|e| {
67            DatabaseInconsistencyError::on("oauth2_authorization_grants")
68                .column("scope")
69                .row(id)
70                .source(e)
71        })?;
72
73        let stage = match (
74            value.fulfilled_at,
75            value.exchanged_at,
76            value.cancelled_at,
77            value.oauth2_session_id,
78        ) {
79            (None, None, None, None) => AuthorizationGrantStage::Pending,
80            (Some(fulfilled_at), None, None, Some(session_id)) => {
81                AuthorizationGrantStage::Fulfilled {
82                    session_id: session_id.into(),
83                    fulfilled_at,
84                }
85            }
86            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
87                AuthorizationGrantStage::Exchanged {
88                    session_id: session_id.into(),
89                    fulfilled_at,
90                    exchanged_at,
91                }
92            }
93            (None, None, Some(cancelled_at), None) => {
94                AuthorizationGrantStage::Cancelled { cancelled_at }
95            }
96            _ => {
97                return Err(
98                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
99                        .column("stage")
100                        .row(id),
101                );
102            }
103        };
104
105        let pkce = match (value.code_challenge, value.code_challenge_method) {
106            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
107                Some(Pkce {
108                    challenge_method: PkceCodeChallengeMethod::Plain,
109                    challenge,
110                })
111            }
112            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
113                challenge_method: PkceCodeChallengeMethod::S256,
114                challenge,
115            }),
116            (None, None) => None,
117            _ => {
118                return Err(
119                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
120                        .column("code_challenge_method")
121                        .row(id),
122                );
123            }
124        };
125
126        let code: Option<AuthorizationCode> =
127            match (value.response_type_code, value.authorization_code, pkce) {
128                (false, None, None) => None,
129                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
130                _ => {
131                    return Err(
132                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
133                            .column("authorization_code")
134                            .row(id),
135                    );
136                }
137            };
138
139        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
140            DatabaseInconsistencyError::on("oauth2_authorization_grants")
141                .column("redirect_uri")
142                .row(id)
143                .source(e)
144        })?;
145
146        let response_mode = value.response_mode.parse().map_err(|e| {
147            DatabaseInconsistencyError::on("oauth2_authorization_grants")
148                .column("response_mode")
149                .row(id)
150                .source(e)
151        })?;
152
153        Ok(AuthorizationGrant {
154            id,
155            stage,
156            client_id: value.oauth2_client_id.into(),
157            code,
158            scope,
159            state: value.state,
160            nonce: value.nonce,
161            response_mode,
162            redirect_uri,
163            created_at: value.created_at,
164            response_type_id_token: value.response_type_id_token,
165            login_hint: value.login_hint,
166            locale: value.locale,
167        })
168    }
169}
170
171#[async_trait]
172impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
173    type Error = DatabaseError;
174
175    #[tracing::instrument(
176        name = "db.oauth2_authorization_grant.add",
177        skip_all,
178        fields(
179            db.query.text,
180            grant.id,
181            grant.scope = %scope,
182            %client.id,
183        ),
184        err,
185    )]
186    async fn add(
187        &mut self,
188        rng: &mut (dyn RngCore + Send),
189        clock: &dyn Clock,
190        client: &Client,
191        redirect_uri: Url,
192        scope: Scope,
193        code: Option<AuthorizationCode>,
194        state: Option<String>,
195        nonce: Option<String>,
196        response_mode: ResponseMode,
197        response_type_id_token: bool,
198        login_hint: Option<String>,
199        locale: Option<String>,
200    ) -> Result<AuthorizationGrant, Self::Error> {
201        let code_challenge = code
202            .as_ref()
203            .and_then(|c| c.pkce.as_ref())
204            .map(|p| &p.challenge);
205        let code_challenge_method = code
206            .as_ref()
207            .and_then(|c| c.pkce.as_ref())
208            .map(|p| p.challenge_method.to_string());
209        let code_str = code.as_ref().map(|c| &c.code);
210
211        let created_at = clock.now();
212        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
213        tracing::Span::current().record("grant.id", tracing::field::display(id));
214
215        sqlx::query!(
216            r#"
217                INSERT INTO oauth2_authorization_grants (
218                     oauth2_authorization_grant_id,
219                     oauth2_client_id,
220                     redirect_uri,
221                     scope,
222                     state,
223                     nonce,
224                     response_mode,
225                     code_challenge,
226                     code_challenge_method,
227                     response_type_code,
228                     response_type_id_token,
229                     authorization_code,
230                     login_hint,
231                     locale,
232                     created_at
233                )
234                VALUES
235                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
236            "#,
237            Uuid::from(id),
238            Uuid::from(client.id),
239            redirect_uri.to_string(),
240            scope.to_string(),
241            state,
242            nonce,
243            response_mode.to_string(),
244            code_challenge,
245            code_challenge_method,
246            code.is_some(),
247            response_type_id_token,
248            code_str,
249            login_hint,
250            locale,
251            created_at,
252        )
253        .traced()
254        .execute(&mut *self.conn)
255        .await?;
256
257        Ok(AuthorizationGrant {
258            id,
259            stage: AuthorizationGrantStage::Pending,
260            code,
261            redirect_uri,
262            client_id: client.id,
263            scope,
264            state,
265            nonce,
266            response_mode,
267            created_at,
268            response_type_id_token,
269            login_hint,
270            locale,
271        })
272    }
273
274    #[tracing::instrument(
275        name = "db.oauth2_authorization_grant.lookup",
276        skip_all,
277        fields(
278            db.query.text,
279            grant.id = %id,
280        ),
281        err,
282    )]
283    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
284        let res = sqlx::query_as!(
285            GrantLookup,
286            r#"
287                SELECT oauth2_authorization_grant_id
288                     , created_at
289                     , cancelled_at
290                     , fulfilled_at
291                     , exchanged_at
292                     , scope
293                     , state
294                     , redirect_uri
295                     , response_mode
296                     , nonce
297                     , oauth2_client_id
298                     , authorization_code
299                     , response_type_code
300                     , response_type_id_token
301                     , code_challenge
302                     , code_challenge_method
303                     , login_hint
304                     , locale
305                     , oauth2_session_id
306                FROM
307                    oauth2_authorization_grants
308
309                WHERE oauth2_authorization_grant_id = $1
310            "#,
311            Uuid::from(id),
312        )
313        .traced()
314        .fetch_optional(&mut *self.conn)
315        .await?;
316
317        let Some(res) = res else { return Ok(None) };
318
319        Ok(Some(res.try_into()?))
320    }
321
322    #[tracing::instrument(
323        name = "db.oauth2_authorization_grant.find_by_code",
324        skip_all,
325        fields(
326            db.query.text,
327        ),
328        err,
329    )]
330    async fn find_by_code(
331        &mut self,
332        code: &str,
333    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
334        let res = sqlx::query_as!(
335            GrantLookup,
336            r#"
337                SELECT oauth2_authorization_grant_id
338                     , created_at
339                     , cancelled_at
340                     , fulfilled_at
341                     , exchanged_at
342                     , scope
343                     , state
344                     , redirect_uri
345                     , response_mode
346                     , nonce
347                     , oauth2_client_id
348                     , authorization_code
349                     , response_type_code
350                     , response_type_id_token
351                     , code_challenge
352                     , code_challenge_method
353                     , login_hint
354                     , locale
355                     , oauth2_session_id
356                FROM
357                    oauth2_authorization_grants
358
359                WHERE authorization_code = $1
360            "#,
361            code,
362        )
363        .traced()
364        .fetch_optional(&mut *self.conn)
365        .await?;
366
367        let Some(res) = res else { return Ok(None) };
368
369        Ok(Some(res.try_into()?))
370    }
371
372    #[tracing::instrument(
373        name = "db.oauth2_authorization_grant.fulfill",
374        skip_all,
375        fields(
376            db.query.text,
377            %grant.id,
378            client.id = %grant.client_id,
379            %session.id,
380        ),
381        err,
382    )]
383    async fn fulfill(
384        &mut self,
385        clock: &dyn Clock,
386        session: &Session,
387        grant: AuthorizationGrant,
388    ) -> Result<AuthorizationGrant, Self::Error> {
389        let fulfilled_at = clock.now();
390        let res = sqlx::query!(
391            r#"
392                UPDATE oauth2_authorization_grants
393                SET fulfilled_at = $2
394                  , oauth2_session_id = $3
395                WHERE oauth2_authorization_grant_id = $1
396            "#,
397            Uuid::from(grant.id),
398            fulfilled_at,
399            Uuid::from(session.id),
400        )
401        .traced()
402        .execute(&mut *self.conn)
403        .await?;
404
405        DatabaseError::ensure_affected_rows(&res, 1)?;
406
407        // XXX: check affected rows & new methods
408        let grant = grant
409            .fulfill(fulfilled_at, session)
410            .map_err(DatabaseError::to_invalid_operation)?;
411
412        Ok(grant)
413    }
414
415    #[tracing::instrument(
416        name = "db.oauth2_authorization_grant.exchange",
417        skip_all,
418        fields(
419            db.query.text,
420            %grant.id,
421            client.id = %grant.client_id,
422        ),
423        err,
424    )]
425    async fn exchange(
426        &mut self,
427        clock: &dyn Clock,
428        grant: AuthorizationGrant,
429    ) -> Result<AuthorizationGrant, Self::Error> {
430        let exchanged_at = clock.now();
431        let res = sqlx::query!(
432            r#"
433                UPDATE oauth2_authorization_grants
434                SET exchanged_at = $2
435                WHERE oauth2_authorization_grant_id = $1
436            "#,
437            Uuid::from(grant.id),
438            exchanged_at,
439        )
440        .traced()
441        .execute(&mut *self.conn)
442        .await?;
443
444        DatabaseError::ensure_affected_rows(&res, 1)?;
445
446        let grant = grant
447            .exchange(exchanged_at)
448            .map_err(DatabaseError::to_invalid_operation)?;
449
450        Ok(grant)
451    }
452}