mas_data_model/oauth2/authorization_grant.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use chrono::{DateTime, Utc};
8use mas_iana::oauth::PkceCodeChallengeMethod;
9use oauth2_types::{
10    pkce::{CodeChallengeError, CodeChallengeMethodExt},
11    requests::ResponseMode,
12    scope::{OPENID, PROFILE, Scope},
13};
14use rand::{
15    RngCore,
16    distributions::{Alphanumeric, DistString},
17};
18use ruma_common::UserId;
19use serde::Serialize;
20use ulid::Ulid;
21use url::Url;
22
23use super::session::Session;
24use crate::InvalidTransitionError;
25
26#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
27pub struct Pkce {
28    pub challenge_method: PkceCodeChallengeMethod,
29    pub challenge: String,
30}
31
32impl Pkce {
33    /// Create a new PKCE challenge, with the given method and challenge.
34    #[must_use]
35    pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
36        Pkce {
37            challenge_method,
38            challenge,
39        }
40    }
41
42    /// Verify the PKCE challenge.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the verifier is invalid.
47    pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
48        self.challenge_method.verify(&self.challenge, verifier)
49    }
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
53pub struct AuthorizationCode {
54    pub code: String,
55    pub pkce: Option<Pkce>,
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
59#[serde(tag = "stage", rename_all = "lowercase")]
60pub enum AuthorizationGrantStage {
61    #[default]
62    Pending,
63    Fulfilled {
64        session_id: Ulid,
65        fulfilled_at: DateTime<Utc>,
66    },
67    Exchanged {
68        session_id: Ulid,
69        fulfilled_at: DateTime<Utc>,
70        exchanged_at: DateTime<Utc>,
71    },
72    Cancelled {
73        cancelled_at: DateTime<Utc>,
74    },
75}
76
77impl AuthorizationGrantStage {
78    #[must_use]
79    pub fn new() -> Self {
80        Self::Pending
81    }
82
83    fn fulfill(
84        self,
85        fulfilled_at: DateTime<Utc>,
86        session: &Session,
87    ) -> Result<Self, InvalidTransitionError> {
88        match self {
89            Self::Pending => Ok(Self::Fulfilled {
90                fulfilled_at,
91                session_id: session.id,
92            }),
93            _ => Err(InvalidTransitionError),
94        }
95    }
96
97    fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
98        match self {
99            Self::Fulfilled {
100                fulfilled_at,
101                session_id,
102            } => Ok(Self::Exchanged {
103                fulfilled_at,
104                exchanged_at,
105                session_id,
106            }),
107            _ => Err(InvalidTransitionError),
108        }
109    }
110
111    fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
112        match self {
113            Self::Pending => Ok(Self::Cancelled { cancelled_at }),
114            _ => Err(InvalidTransitionError),
115        }
116    }
117
118    /// Returns `true` if the authorization grant stage is [`Pending`].
119    ///
120    /// [`Pending`]: AuthorizationGrantStage::Pending
121    #[must_use]
122    pub fn is_pending(&self) -> bool {
123        matches!(self, Self::Pending)
124    }
125
126    /// Returns `true` if the authorization grant stage is [`Fulfilled`].
127    ///
128    /// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled
129    #[must_use]
130    pub fn is_fulfilled(&self) -> bool {
131        matches!(self, Self::Fulfilled { .. })
132    }
133
134    /// Returns `true` if the authorization grant stage is [`Exchanged`].
135    ///
136    /// [`Exchanged`]: AuthorizationGrantStage::Exchanged
137    #[must_use]
138    pub fn is_exchanged(&self) -> bool {
139        matches!(self, Self::Exchanged { .. })
140    }
141}
142
143pub enum LoginHint<'a> {
144    MXID(&'a UserId),
145    None,
146}
147
148#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
149pub struct AuthorizationGrant {
150    pub id: Ulid,
151    #[serde(flatten)]
152    pub stage: AuthorizationGrantStage,
153    pub code: Option<AuthorizationCode>,
154    pub client_id: Ulid,
155    pub redirect_uri: Url,
156    pub scope: Scope,
157    pub state: Option<String>,
158    pub nonce: Option<String>,
159    pub response_mode: ResponseMode,
160    pub response_type_id_token: bool,
161    pub created_at: DateTime<Utc>,
162    pub login_hint: Option<String>,
163    pub locale: Option<String>,
164}
165
166impl std::ops::Deref for AuthorizationGrant {
167    type Target = AuthorizationGrantStage;
168
169    fn deref(&self) -> &Self::Target {
170        &self.stage
171    }
172}
173
174impl AuthorizationGrant {
175    #[must_use]
176    pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
177        let Some(login_hint) = &self.login_hint else {
178            return LoginHint::None;
179        };
180
181        // Return none if the format is incorrect
182        let Some((prefix, value)) = login_hint.split_once(':') else {
183            return LoginHint::None;
184        };
185
186        match prefix {
187            "mxid" => {
188                // Instead of erroring just return none
189                let Ok(mxid) = <&UserId>::try_from(value) else {
190                    return LoginHint::None;
191                };
192
193                // Only handle MXIDs for current homeserver
194                if mxid.server_name() != homeserver {
195                    return LoginHint::None;
196                }
197
198                LoginHint::MXID(mxid)
199            }
200            // Unknown hint type, treat as none
201            _ => LoginHint::None,
202        }
203    }
204
205    /// Mark the authorization grant as exchanged.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if the authorization grant is not [`Fulfilled`].
210    ///
211    /// [`Fulfilled`]: AuthorizationGrantStage::Fulfilled
212    pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
213        self.stage = self.stage.exchange(exchanged_at)?;
214        Ok(self)
215    }
216
217    /// Mark the authorization grant as fulfilled.
218    ///
219    /// # Errors
220    ///
221    /// Returns an error if the authorization grant is not [`Pending`].
222    ///
223    /// [`Pending`]: AuthorizationGrantStage::Pending
224    pub fn fulfill(
225        mut self,
226        fulfilled_at: DateTime<Utc>,
227        session: &Session,
228    ) -> Result<Self, InvalidTransitionError> {
229        self.stage = self.stage.fulfill(fulfilled_at, session)?;
230        Ok(self)
231    }
232
233    /// Mark the authorization grant as cancelled.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if the authorization grant is not [`Pending`].
238    ///
239    /// [`Pending`]: AuthorizationGrantStage::Pending
240    ///
241    /// # TODO
242    ///
243    /// This appears to be unused
244    pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
245        self.stage = self.stage.cancel(canceld_at)?;
246        Ok(self)
247    }
248
249    #[doc(hidden)]
250    pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
251        Self {
252            id: Ulid::from_datetime_with_source(now.into(), rng),
253            stage: AuthorizationGrantStage::Pending,
254            code: Some(AuthorizationCode {
255                code: Alphanumeric.sample_string(rng, 10),
256                pkce: None,
257            }),
258            client_id: Ulid::from_datetime_with_source(now.into(), rng),
259            redirect_uri: Url::parse("http://localhost:8080").unwrap(),
260            scope: Scope::from_iter([OPENID, PROFILE]),
261            state: Some(Alphanumeric.sample_string(rng, 10)),
262            nonce: Some(Alphanumeric.sample_string(rng, 10)),
263            response_mode: ResponseMode::Query,
264            response_type_id_token: false,
265            created_at: now,
266            login_hint: Some(String::from("mxid:@example-user:example.com")),
267            locale: Some(String::from("fr")),
268        }
269    }
270}
271
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    /// Build a sample grant with its `login_hint` replaced by `login_hint`.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn invalid_login_hint() {
        let grant = grant_with_login_hint(Some("example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn malformed_mxid_login_hint() {
        // Known hint type, but the value is not a user ID (no `@` sigil).
        let grant = grant_with_login_hint(Some("mxid:example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = grant_with_login_hint(Some("something:anything"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn pending_grant_cannot_be_exchanged() {
        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        let grant = grant_with_login_hint(None);
        assert!(grant.is_pending());

        assert!(grant.exchange(now).is_err());
    }

    #[test]
    fn pending_grant_can_be_cancelled_once() {
        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        let Ok(cancelled) = grant_with_login_hint(None).cancel(now) else {
            panic!("a pending grant should be cancellable");
        };

        assert!(matches!(
            cancelled.stage,
            AuthorizationGrantStage::Cancelled { cancelled_at } if cancelled_at == now
        ));

        // A cancelled grant accepts no further transitions.
        assert!(cancelled.clone().cancel(now).is_err());
        assert!(cancelled.exchange(now).is_err());
    }
}