1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
14use oauth2_types::{requests::ResponseMode, scope::Scope};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
22
23pub struct PgOAuth2AuthorizationGrantRepository<'c> {
26 conn: &'c mut PgConnection,
27}
28
29impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
30 pub fn new(conn: &'c mut PgConnection) -> Self {
33 Self { conn }
34 }
35}
36
37#[allow(clippy::struct_excessive_bools)]
38struct GrantLookup {
39 oauth2_authorization_grant_id: Uuid,
40 created_at: DateTime<Utc>,
41 cancelled_at: Option<DateTime<Utc>>,
42 fulfilled_at: Option<DateTime<Utc>>,
43 exchanged_at: Option<DateTime<Utc>>,
44 scope: String,
45 state: Option<String>,
46 nonce: Option<String>,
47 redirect_uri: String,
48 response_mode: String,
49 response_type_code: bool,
50 response_type_id_token: bool,
51 authorization_code: Option<String>,
52 code_challenge: Option<String>,
53 code_challenge_method: Option<String>,
54 login_hint: Option<String>,
55 locale: Option<String>,
56 oauth2_client_id: Uuid,
57 oauth2_session_id: Option<Uuid>,
58}
59
60impl TryFrom<GrantLookup> for AuthorizationGrant {
61 type Error = DatabaseInconsistencyError;
62
63 #[allow(clippy::too_many_lines)]
64 fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
65 let id = value.oauth2_authorization_grant_id.into();
66 let scope: Scope = value.scope.parse().map_err(|e| {
67 DatabaseInconsistencyError::on("oauth2_authorization_grants")
68 .column("scope")
69 .row(id)
70 .source(e)
71 })?;
72
73 let stage = match (
74 value.fulfilled_at,
75 value.exchanged_at,
76 value.cancelled_at,
77 value.oauth2_session_id,
78 ) {
79 (None, None, None, None) => AuthorizationGrantStage::Pending,
80 (Some(fulfilled_at), None, None, Some(session_id)) => {
81 AuthorizationGrantStage::Fulfilled {
82 session_id: session_id.into(),
83 fulfilled_at,
84 }
85 }
86 (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
87 AuthorizationGrantStage::Exchanged {
88 session_id: session_id.into(),
89 fulfilled_at,
90 exchanged_at,
91 }
92 }
93 (None, None, Some(cancelled_at), None) => {
94 AuthorizationGrantStage::Cancelled { cancelled_at }
95 }
96 _ => {
97 return Err(
98 DatabaseInconsistencyError::on("oauth2_authorization_grants")
99 .column("stage")
100 .row(id),
101 );
102 }
103 };
104
105 let pkce = match (value.code_challenge, value.code_challenge_method) {
106 (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
107 Some(Pkce {
108 challenge_method: PkceCodeChallengeMethod::Plain,
109 challenge,
110 })
111 }
112 (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
113 challenge_method: PkceCodeChallengeMethod::S256,
114 challenge,
115 }),
116 (None, None) => None,
117 _ => {
118 return Err(
119 DatabaseInconsistencyError::on("oauth2_authorization_grants")
120 .column("code_challenge_method")
121 .row(id),
122 );
123 }
124 };
125
126 let code: Option<AuthorizationCode> =
127 match (value.response_type_code, value.authorization_code, pkce) {
128 (false, None, None) => None,
129 (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
130 _ => {
131 return Err(
132 DatabaseInconsistencyError::on("oauth2_authorization_grants")
133 .column("authorization_code")
134 .row(id),
135 );
136 }
137 };
138
139 let redirect_uri = value.redirect_uri.parse().map_err(|e| {
140 DatabaseInconsistencyError::on("oauth2_authorization_grants")
141 .column("redirect_uri")
142 .row(id)
143 .source(e)
144 })?;
145
146 let response_mode = value.response_mode.parse().map_err(|e| {
147 DatabaseInconsistencyError::on("oauth2_authorization_grants")
148 .column("response_mode")
149 .row(id)
150 .source(e)
151 })?;
152
153 Ok(AuthorizationGrant {
154 id,
155 stage,
156 client_id: value.oauth2_client_id.into(),
157 code,
158 scope,
159 state: value.state,
160 nonce: value.nonce,
161 response_mode,
162 redirect_uri,
163 created_at: value.created_at,
164 response_type_id_token: value.response_type_id_token,
165 login_hint: value.login_hint,
166 locale: value.locale,
167 })
168 }
169}
170
171#[async_trait]
172impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
173 type Error = DatabaseError;
174
175 #[tracing::instrument(
176 name = "db.oauth2_authorization_grant.add",
177 skip_all,
178 fields(
179 db.query.text,
180 grant.id,
181 grant.scope = %scope,
182 %client.id,
183 ),
184 err,
185 )]
186 async fn add(
187 &mut self,
188 rng: &mut (dyn RngCore + Send),
189 clock: &dyn Clock,
190 client: &Client,
191 redirect_uri: Url,
192 scope: Scope,
193 code: Option<AuthorizationCode>,
194 state: Option<String>,
195 nonce: Option<String>,
196 response_mode: ResponseMode,
197 response_type_id_token: bool,
198 login_hint: Option<String>,
199 locale: Option<String>,
200 ) -> Result<AuthorizationGrant, Self::Error> {
201 let code_challenge = code
202 .as_ref()
203 .and_then(|c| c.pkce.as_ref())
204 .map(|p| &p.challenge);
205 let code_challenge_method = code
206 .as_ref()
207 .and_then(|c| c.pkce.as_ref())
208 .map(|p| p.challenge_method.to_string());
209 let code_str = code.as_ref().map(|c| &c.code);
210
211 let created_at = clock.now();
212 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
213 tracing::Span::current().record("grant.id", tracing::field::display(id));
214
215 sqlx::query!(
216 r#"
217 INSERT INTO oauth2_authorization_grants (
218 oauth2_authorization_grant_id,
219 oauth2_client_id,
220 redirect_uri,
221 scope,
222 state,
223 nonce,
224 response_mode,
225 code_challenge,
226 code_challenge_method,
227 response_type_code,
228 response_type_id_token,
229 authorization_code,
230 login_hint,
231 locale,
232 created_at
233 )
234 VALUES
235 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
236 "#,
237 Uuid::from(id),
238 Uuid::from(client.id),
239 redirect_uri.to_string(),
240 scope.to_string(),
241 state,
242 nonce,
243 response_mode.to_string(),
244 code_challenge,
245 code_challenge_method,
246 code.is_some(),
247 response_type_id_token,
248 code_str,
249 login_hint,
250 locale,
251 created_at,
252 )
253 .traced()
254 .execute(&mut *self.conn)
255 .await?;
256
257 Ok(AuthorizationGrant {
258 id,
259 stage: AuthorizationGrantStage::Pending,
260 code,
261 redirect_uri,
262 client_id: client.id,
263 scope,
264 state,
265 nonce,
266 response_mode,
267 created_at,
268 response_type_id_token,
269 login_hint,
270 locale,
271 })
272 }
273
274 #[tracing::instrument(
275 name = "db.oauth2_authorization_grant.lookup",
276 skip_all,
277 fields(
278 db.query.text,
279 grant.id = %id,
280 ),
281 err,
282 )]
283 async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
284 let res = sqlx::query_as!(
285 GrantLookup,
286 r#"
287 SELECT oauth2_authorization_grant_id
288 , created_at
289 , cancelled_at
290 , fulfilled_at
291 , exchanged_at
292 , scope
293 , state
294 , redirect_uri
295 , response_mode
296 , nonce
297 , oauth2_client_id
298 , authorization_code
299 , response_type_code
300 , response_type_id_token
301 , code_challenge
302 , code_challenge_method
303 , login_hint
304 , locale
305 , oauth2_session_id
306 FROM
307 oauth2_authorization_grants
308
309 WHERE oauth2_authorization_grant_id = $1
310 "#,
311 Uuid::from(id),
312 )
313 .traced()
314 .fetch_optional(&mut *self.conn)
315 .await?;
316
317 let Some(res) = res else { return Ok(None) };
318
319 Ok(Some(res.try_into()?))
320 }
321
322 #[tracing::instrument(
323 name = "db.oauth2_authorization_grant.find_by_code",
324 skip_all,
325 fields(
326 db.query.text,
327 ),
328 err,
329 )]
330 async fn find_by_code(
331 &mut self,
332 code: &str,
333 ) -> Result<Option<AuthorizationGrant>, Self::Error> {
334 let res = sqlx::query_as!(
335 GrantLookup,
336 r#"
337 SELECT oauth2_authorization_grant_id
338 , created_at
339 , cancelled_at
340 , fulfilled_at
341 , exchanged_at
342 , scope
343 , state
344 , redirect_uri
345 , response_mode
346 , nonce
347 , oauth2_client_id
348 , authorization_code
349 , response_type_code
350 , response_type_id_token
351 , code_challenge
352 , code_challenge_method
353 , login_hint
354 , locale
355 , oauth2_session_id
356 FROM
357 oauth2_authorization_grants
358
359 WHERE authorization_code = $1
360 "#,
361 code,
362 )
363 .traced()
364 .fetch_optional(&mut *self.conn)
365 .await?;
366
367 let Some(res) = res else { return Ok(None) };
368
369 Ok(Some(res.try_into()?))
370 }
371
372 #[tracing::instrument(
373 name = "db.oauth2_authorization_grant.fulfill",
374 skip_all,
375 fields(
376 db.query.text,
377 %grant.id,
378 client.id = %grant.client_id,
379 %session.id,
380 ),
381 err,
382 )]
383 async fn fulfill(
384 &mut self,
385 clock: &dyn Clock,
386 session: &Session,
387 grant: AuthorizationGrant,
388 ) -> Result<AuthorizationGrant, Self::Error> {
389 let fulfilled_at = clock.now();
390 let res = sqlx::query!(
391 r#"
392 UPDATE oauth2_authorization_grants
393 SET fulfilled_at = $2
394 , oauth2_session_id = $3
395 WHERE oauth2_authorization_grant_id = $1
396 "#,
397 Uuid::from(grant.id),
398 fulfilled_at,
399 Uuid::from(session.id),
400 )
401 .traced()
402 .execute(&mut *self.conn)
403 .await?;
404
405 DatabaseError::ensure_affected_rows(&res, 1)?;
406
407 let grant = grant
409 .fulfill(fulfilled_at, session)
410 .map_err(DatabaseError::to_invalid_operation)?;
411
412 Ok(grant)
413 }
414
415 #[tracing::instrument(
416 name = "db.oauth2_authorization_grant.exchange",
417 skip_all,
418 fields(
419 db.query.text,
420 %grant.id,
421 client.id = %grant.client_id,
422 ),
423 err,
424 )]
425 async fn exchange(
426 &mut self,
427 clock: &dyn Clock,
428 grant: AuthorizationGrant,
429 ) -> Result<AuthorizationGrant, Self::Error> {
430 let exchanged_at = clock.now();
431 let res = sqlx::query!(
432 r#"
433 UPDATE oauth2_authorization_grants
434 SET exchanged_at = $2
435 WHERE oauth2_authorization_grant_id = $1
436 "#,
437 Uuid::from(grant.id),
438 exchanged_at,
439 )
440 .traced()
441 .execute(&mut *self.conn)
442 .await?;
443
444 DatabaseError::ensure_affected_rows(&res, 1)?;
445
446 let grant = grant
447 .exchange(exchanged_at)
448 .map_err(DatabaseError::to_invalid_operation)?;
449
450 Ok(grant)
451 }
452}