mas_storage_pg/oauth2/
mod.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7//! A module containing the PostgreSQL implementations of the OAuth2-related
8//! repositories
9
10mod access_token;
11mod authorization_grant;
12mod client;
13mod device_code_grant;
14mod refresh_token;
15mod session;
16
17pub use self::{
18    access_token::PgOAuth2AccessTokenRepository,
19    authorization_grant::PgOAuth2AuthorizationGrantRepository, client::PgOAuth2ClientRepository,
20    device_code_grant::PgOAuth2DeviceCodeGrantRepository,
21    refresh_token::PgOAuth2RefreshTokenRepository, session::PgOAuth2SessionRepository,
22};
23
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::AuthorizationCode;
    use mas_storage::{
        Clock, Pagination,
        clock::MockClock,
        oauth2::{OAuth2DeviceCodeGrantParams, OAuth2SessionFilter, OAuth2SessionRepository},
    };
    use oauth2_types::{
        requests::{GrantType, ResponseMode},
        scope::{EMAIL, OPENID, PROFILE, Scope},
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;
    use ulid::Ulid;

    use crate::PgRepository;

    /// Walk through the client, authorization grant, session, access token
    /// and refresh token repositories, checking lookups of missing and
    /// freshly-created entities as well as the state transitions of each.
    ///
    /// Every `add` call receives the same seeded RNG and [`MockClock`], so
    /// the order of the calls below is significant.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repositories(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Look up a non-existing client
        let client = repo.oauth2_client().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(client, None);

        // Find a non-existing client by client id
        let client = repo
            .oauth2_client()
            .find_by_client_id("some-client-id")
            .await
            .unwrap();
        assert_eq!(client, None);

        // Create a client
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Test client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        // Look up the same client by id
        let client_lookup = repo
            .oauth2_client()
            .lookup(client.id)
            .await
            .unwrap()
            .expect("client not found");
        assert_eq!(client, client_lookup);

        // Find the same client by client id
        let client_lookup = repo
            .oauth2_client()
            .find_by_client_id(&client.client_id)
            .await
            .unwrap()
            .expect("client not found");
        assert_eq!(client, client_lookup);

        // Look up a non-existing grant
        let grant = repo
            .oauth2_authorization_grant()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(grant, None);

        // Find a non-existing grant by code
        let grant = repo
            .oauth2_authorization_grant()
            .find_by_code("code")
            .await
            .unwrap();
        assert_eq!(grant, None);

        // Create an authorization grant
        let grant = repo
            .oauth2_authorization_grant()
            .add(
                &mut rng,
                &clock,
                &client,
                "https://example.com/redirect".parse().unwrap(),
                Scope::from_iter([OPENID]),
                Some(AuthorizationCode {
                    code: "code".to_owned(),
                    pkce: None,
                }),
                Some("state".to_owned()),
                Some("nonce".to_owned()),
                ResponseMode::Query,
                true,
                None,
                None,
            )
            .await
            .unwrap();
        assert!(grant.is_pending());

        // Look up the same grant by id
        let grant_lookup = repo
            .oauth2_authorization_grant()
            .lookup(grant.id)
            .await
            .unwrap()
            .expect("grant not found");
        assert_eq!(grant, grant_lookup);

        // Find the same grant by code
        let grant_lookup = repo
            .oauth2_authorization_grant()
            .find_by_code("code")
            .await
            .unwrap()
            .expect("grant not found");
        assert_eq!(grant, grant_lookup);

        // Create a user and start a browser session for them
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let user_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        // Look up a non-existing session
        let session = repo.oauth2_session().lookup(Ulid::nil()).await.unwrap();
        assert_eq!(session, None);

        // Create an OAuth session
        let session = repo
            .oauth2_session()
            .add_from_browser_session(
                &mut rng,
                &clock,
                &client,
                &user_session,
                grant.scope.clone(),
            )
            .await
            .unwrap();

        // Mark the grant as fulfilled
        let grant = repo
            .oauth2_authorization_grant()
            .fulfill(&clock, &session, grant)
            .await
            .unwrap();
        assert!(grant.is_fulfilled());

        // Look up the same session by id
        let session_lookup = repo
            .oauth2_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session not found");
        assert_eq!(session, session_lookup);

        // Mark the grant as exchanged
        let grant = repo
            .oauth2_authorization_grant()
            .exchange(&clock, grant)
            .await
            .unwrap();
        assert!(grant.is_exchanged());

        // Look up a non-existing token
        let token = repo
            .oauth2_access_token()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(token, None);

        // Find a non-existing token
        let token = repo
            .oauth2_access_token()
            .find_by_token("aabbcc")
            .await
            .unwrap();
        assert_eq!(token, None);

        // Create an access token with a 5 minute lifetime
        let access_token = repo
            .oauth2_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                "aabbcc".to_owned(),
                Some(Duration::try_minutes(5).unwrap()),
            )
            .await
            .unwrap();

        // Look up the same token by id
        let access_token_lookup = repo
            .oauth2_access_token()
            .lookup(access_token.id)
            .await
            .unwrap()
            .expect("token not found");
        assert_eq!(access_token, access_token_lookup);

        // Find the same token by token
        let access_token_lookup = repo
            .oauth2_access_token()
            .find_by_token("aabbcc")
            .await
            .unwrap()
            .expect("token not found");
        assert_eq!(access_token, access_token_lookup);

        // Look up a non-existing refresh token
        let refresh_token = repo
            .oauth2_refresh_token()
            .lookup(Ulid::nil())
            .await
            .unwrap();
        assert_eq!(refresh_token, None);

        // Find a non-existing refresh token
        let refresh_token = repo
            .oauth2_refresh_token()
            .find_by_token("aabbcc")
            .await
            .unwrap();
        assert_eq!(refresh_token, None);

        // Create a refresh token
        let refresh_token = repo
            .oauth2_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                "aabbcc".to_owned(),
            )
            .await
            .unwrap();

        // Look up the same refresh token by id
        let refresh_token_lookup = repo
            .oauth2_refresh_token()
            .lookup(refresh_token.id)
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token, refresh_token_lookup);

        // Find the same refresh token by token
        let refresh_token_lookup = repo
            .oauth2_refresh_token()
            .find_by_token("aabbcc")
            .await
            .unwrap()
            .expect("refresh token not found");
        assert_eq!(refresh_token, refresh_token_lookup);

        // The access token (5 minute lifetime) is valid now, but no longer
        // after 6 minutes have passed
        assert!(access_token.is_valid(clock.now()));
        clock.advance(Duration::try_minutes(6).unwrap());
        assert!(!access_token.is_valid(clock.now()));

        // XXX: we might want to create a new access token
        clock.advance(Duration::try_minutes(-6).unwrap()); // Go back in time
        assert!(access_token.is_valid(clock.now()));

        // Create a new refresh token to be able to consume the old one
        let new_refresh_token = repo
            .oauth2_refresh_token()
            .add(
                &mut rng,
                &clock,
                &session,
                &access_token,
                "ddeeff".to_owned(),
            )
            .await
            .unwrap();

        // Mark the access token as revoked
        let access_token = repo
            .oauth2_access_token()
            .revoke(&clock, access_token)
            .await
            .unwrap();
        assert!(!access_token.is_valid(clock.now()));

        // Mark the refresh token as consumed
        assert!(refresh_token.is_valid());
        let refresh_token = repo
            .oauth2_refresh_token()
            .consume(&clock, refresh_token, &new_refresh_token)
            .await
            .unwrap();
        assert!(!refresh_token.is_valid());

        // Record the user-agent on the session
        assert!(session.user_agent.is_none());
        let session = repo
            .oauth2_session()
            .record_user_agent(session, "Mozilla/5.0".to_owned())
            .await
            .unwrap();
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Reload the session and check that the user-agent was persisted
        let session = repo
            .oauth2_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session not found");
        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));

        // Mark the session as finished
        assert!(session.is_valid());
        let session = repo.oauth2_session().finish(&clock, session).await.unwrap();
        assert!(!session.is_valid());
    }

    /// Test the [`OAuth2SessionRepository::list`],
    /// [`OAuth2SessionRepository::count`] and
    /// [`OAuth2SessionRepository::finish_bulk`] methods, with user, client,
    /// active/finished and scope filters used alone and in combination.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_list_sessions(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create two users and their corresponding browser sessions
        let user1 = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();
        let user1_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user1, None)
            .await
            .unwrap();

        let user2 = repo
            .user()
            .add(&mut rng, &clock, "bob".to_owned())
            .await
            .unwrap();
        let user2_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user2, None)
            .await
            .unwrap();

        // Create two clients
        let client1 = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://first.example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("First client".to_owned()),
                Some("https://first.example.com/logo.png".parse().unwrap()),
                Some("https://first.example.com/".parse().unwrap()),
                Some("https://first.example.com/policy".parse().unwrap()),
                Some("https://first.example.com/tos".parse().unwrap()),
                Some("https://first.example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://first.example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();
        let client2 = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://second.example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Second client".to_owned()),
                Some("https://second.example.com/logo.png".parse().unwrap()),
                Some("https://second.example.com/".parse().unwrap()),
                Some("https://second.example.com/policy".parse().unwrap()),
                Some("https://second.example.com/tos".parse().unwrap()),
                Some("https://second.example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://second.example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        let scope = Scope::from_iter([OPENID, EMAIL]);
        let scope2 = Scope::from_iter([OPENID, PROFILE]);

        // Create two sessions for each user, one with each client
        // We're moving the clock forward by 1 minute between each session to ensure
        // we're getting consistent ordering in lists.
        let session11 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client1, &user1_session, scope.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session12 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client1, &user2_session, scope.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session21 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client2, &user1_session, scope2.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        let session22 = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client2, &user2_session, scope2.clone())
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        // We're also finishing two of the sessions
        let session11 = repo
            .oauth2_session()
            .finish(&clock, session11)
            .await
            .unwrap();
        let session22 = repo
            .oauth2_session()
            .finish(&clock, session22)
            .await
            .unwrap();

        let pagination = Pagination::first(10);

        // First, list all the sessions
        let filter = OAuth2SessionFilter::new().for_any_user();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 4);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(list.edges[2], session21);
        assert_eq!(list.edges[3], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);

        // Now filter for only one user
        let filter = OAuth2SessionFilter::new().for_user(&user1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Filter for only one client
        let filter = OAuth2SessionFilter::new().for_client(&client1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Filter for both a user and a client
        let filter = OAuth2SessionFilter::new()
            .for_user(&user2)
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Filter for active sessions
        let filter = OAuth2SessionFilter::new().active_only();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session12);
        assert_eq!(list.edges[1], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Filter for finished sessions
        let filter = OAuth2SessionFilter::new().finished_only();
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Combine the finished filter with the user filter
        let filter = OAuth2SessionFilter::new().finished_only().for_user(&user2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Combine the finished filter with the client filter
        let filter = OAuth2SessionFilter::new()
            .finished_only()
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session22);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Combine the active filter with the user filter
        let filter = OAuth2SessionFilter::new().active_only().for_user(&user2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session12);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Combine the active filter with the client filter
        let filter = OAuth2SessionFilter::new()
            .active_only()
            .for_client(&client2);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session21);

        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Try the scope filter. We should get all sessions with the "openid" scope
        let scope = Scope::from_iter([OPENID]);
        let filter = OAuth2SessionFilter::new().with_scope(&scope);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 4);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(list.edges[2], session21);
        assert_eq!(list.edges[3], session22);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 4);

        // We should get all sessions with both the "openid" and "email" scopes
        let scope = Scope::from_iter([OPENID, EMAIL]);
        let filter = OAuth2SessionFilter::new().with_scope(&scope);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 2);
        assert_eq!(list.edges[0], session11);
        assert_eq!(list.edges[1], session12);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 2);

        // Try combining the scope filter with the user filter
        let filter = OAuth2SessionFilter::new()
            .with_scope(&scope)
            .for_user(&user1);
        let list = repo
            .oauth2_session()
            .list(filter, pagination)
            .await
            .unwrap();
        assert!(!list.has_next_page);
        assert_eq!(list.edges.len(), 1);
        assert_eq!(list.edges[0], session11);
        assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1);

        // Finish all the active sessions of a client in bulk; only session12
        // is still active for client1
        let affected = repo
            .oauth2_session()
            .finish_bulk(
                &clock,
                OAuth2SessionFilter::new()
                    .for_client(&client1)
                    .active_only(),
            )
            .await
            .unwrap();
        assert_eq!(affected, 1);

        // We should have 3 finished sessions
        assert_eq!(
            repo.oauth2_session()
                .count(OAuth2SessionFilter::new().finished_only())
                .await
                .unwrap(),
            3
        );

        // We should have 1 active session left
        assert_eq!(
            repo.oauth2_session()
                .count(OAuth2SessionFilter::new().active_only())
                .await
                .unwrap(),
            1
        );
    }

    /// Test the [`mas_storage::oauth2::OAuth2DeviceCodeGrantRepository`]
    /// implementation: lookups by id, device code and user code, plus which
    /// transitions (fulfill, reject, exchange) succeed or fail from the
    /// pending, fulfilled, exchanged and rejected states.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_device_code_grant_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Provision a client
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Example".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        // Provision a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        // Provision a browser session
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let user_code = "usercode";
        let device_code = "devicecode";
        let scope = Scope::from_iter([OPENID, EMAIL]);

        // Create a device code grant
        let grant = repo
            .oauth2_device_code_grant()
            .add(
                &mut rng,
                &clock,
                OAuth2DeviceCodeGrantParams {
                    client: &client,
                    scope: scope.clone(),
                    device_code: device_code.to_owned(),
                    user_code: user_code.to_owned(),
                    expires_in: Duration::try_minutes(5).unwrap(),
                    ip_address: None,
                    user_agent: None,
                },
            )
            .await
            .unwrap();

        assert!(grant.is_pending());

        // Check that we can find the grant by ID
        let id = grant.id;
        let lookup = repo.oauth2_device_code_grant().lookup(id).await.unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Check that we can find the grant by device code
        let lookup = repo
            .oauth2_device_code_grant()
            .find_by_device_code(device_code)
            .await
            .unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Check that we can find the grant by user code
        let lookup = repo
            .oauth2_device_code_grant()
            .find_by_user_code(user_code)
            .await
            .unwrap();
        assert_eq!(lookup.as_ref(), Some(&grant));

        // Let's mark it as fulfilled
        let grant = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await
            .unwrap();
        assert!(!grant.is_pending());
        assert!(grant.is_fulfilled());

        // Check that we can't mark it as rejected now
        let res = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // The failed call took `grant` by value, so look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // We can't mark it as fulfilled again
        let res = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // Look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // Create an OAuth 2.0 session
        let session = repo
            .oauth2_session()
            .add_from_browser_session(&mut rng, &clock, &client, &browser_session, scope.clone())
            .await
            .unwrap();

        // We can mark it as exchanged
        let grant = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await
            .unwrap();
        assert!(!grant.is_pending());
        assert!(!grant.is_fulfilled());
        assert!(grant.is_exchanged());

        // We can't mark it as exchanged again
        let res = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await;
        assert!(res.is_err());

        // Create a second grant, this time to reject it
        let grant = repo
            .oauth2_device_code_grant()
            .add(
                &mut rng,
                &clock,
                OAuth2DeviceCodeGrantParams {
                    client: &client,
                    scope: scope.clone(),
                    device_code: "second_devicecode".to_owned(),
                    user_code: "second_usercode".to_owned(),
                    expires_in: Duration::try_minutes(5).unwrap(),
                    ip_address: None,
                    user_agent: None,
                },
            )
            .await
            .unwrap();

        let id = grant.id;

        // We can mark it as rejected
        let grant = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await
            .unwrap();
        assert!(!grant.is_pending());
        assert!(grant.is_rejected());

        // We can't mark it as rejected again
        let res = repo
            .oauth2_device_code_grant()
            .reject(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // The failed call took `grant` by value, so look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // We can't mark it as fulfilled
        let res = repo
            .oauth2_device_code_grant()
            .fulfill(&clock, grant, &browser_session)
            .await;
        assert!(res.is_err());

        // Look it up again
        let grant = repo
            .oauth2_device_code_grant()
            .lookup(id)
            .await
            .unwrap()
            .unwrap();

        // We can't mark it as exchanged
        let res = repo
            .oauth2_device_code_grant()
            .exchange(&clock, grant, &session)
            .await;
        assert!(res.is_err());
    }
}