mas_storage_pg/oauth2/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13    Clock, Page, Pagination,
14    oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError, DatabaseInconsistencyError,
26    filter::{Filter, StatementExt},
27    iden::{OAuth2Clients, OAuth2Sessions},
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
33pub struct PgOAuth2SessionRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgOAuth2SessionRepository<'c> {
38    /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
/// Raw row shape of the `oauth2_sessions` table, as read by `lookup`
/// (via `query_as!`) and `list` (via `query_as_with`).
///
/// `#[enum_def]` generates `OAuthSessionLookupIden`, whose variants are used
/// as column aliases in `list` so that `FromRow` can map the selected columns
/// back onto these fields by name. Field names therefore must match those
/// aliases/columns exactly.
///
/// Converted into a domain [`Session`] through its `TryFrom` implementation.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    /// Primary key; converted to a `Ulid` for the domain model.
    oauth2_session_id: Uuid,
    /// Owning user, if any (nullable column).
    user_id: Option<Uuid>,
    /// Browser session the OAuth 2.0 session was created from, if any.
    user_session_id: Option<Uuid>,
    /// Client the session belongs to.
    oauth2_client_id: Uuid,
    /// Granted scope, one scope token per array entry; parsed back into a
    /// `Scope` during conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    /// `None` while the session is valid; set once the session is finished.
    finished_at: Option<DateTime<Utc>>,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    /// Stored as an `inet` column; `lookup` casts it explicitly to `IpAddr`.
    last_active_ip: Option<IpAddr>,
    human_name: Option<String>,
}
60
61impl TryFrom<OAuthSessionLookup> for Session {
62    type Error = DatabaseInconsistencyError;
63
64    fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
65        let id = Ulid::from(value.oauth2_session_id);
66        let scope: Result<Scope, _> = value
67            .scope_list
68            .iter()
69            .map(|s| s.parse::<ScopeToken>())
70            .collect();
71        let scope = scope.map_err(|e| {
72            DatabaseInconsistencyError::on("oauth2_sessions")
73                .column("scope")
74                .row(id)
75                .source(e)
76        })?;
77
78        let state = match value.finished_at {
79            None => SessionState::Valid,
80            Some(finished_at) => SessionState::Finished { finished_at },
81        };
82
83        Ok(Session {
84            id,
85            state,
86            created_at: value.created_at,
87            client_id: value.oauth2_client_id.into(),
88            user_id: value.user_id.map(Ulid::from),
89            user_session_id: value.user_session_id.map(Ulid::from),
90            scope,
91            user_agent: value.user_agent,
92            last_active_at: value.last_active_at,
93            last_active_ip: value.last_active_ip,
94            human_name: value.human_name,
95        })
96    }
97}
98
99impl Filter for OAuth2SessionFilter<'_> {
100    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
101        sea_query::Condition::all()
102            .add_option(self.user().map(|user| {
103                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
104            }))
105            .add_option(self.client().map(|client| {
106                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
107                    .eq(Uuid::from(client.id))
108            }))
109            .add_option(self.client_kind().map(|client_kind| {
110                // This builds either a:
111                // `WHERE oauth2_client_id = ANY(...)`
112                // or a `WHERE oauth2_client_id <> ALL(...)`
113                let static_clients = Query::select()
114                    .expr(Expr::col((
115                        OAuth2Clients::Table,
116                        OAuth2Clients::OAuth2ClientId,
117                    )))
118                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
119                    .from(OAuth2Clients::Table)
120                    .take();
121                if client_kind.is_static() {
122                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
123                        .eq(Expr::any(static_clients))
124                } else {
125                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
126                        .ne(Expr::all(static_clients))
127                }
128            }))
129            .add_option(self.device().map(|device| {
130                if let Ok(scope_token) = device.to_scope_token() {
131                    Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
132                        OAuth2Sessions::Table,
133                        OAuth2Sessions::ScopeList,
134                    ))))
135                } else {
136                    // If the device ID can't be encoded as a scope token, match no rows
137                    Expr::val(false).into()
138                }
139            }))
140            .add_option(self.browser_session().map(|browser_session| {
141                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
142                    .eq(Uuid::from(browser_session.id))
143            }))
144            .add_option(self.state().map(|state| {
145                if state.is_active() {
146                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
147                } else {
148                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
149                }
150            }))
151            .add_option(self.scope().map(|scope| {
152                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
153                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
154            }))
155            .add_option(self.any_user().map(|any_user| {
156                if any_user {
157                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
158                } else {
159                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
160                }
161            }))
162            .add_option(self.last_active_after().map(|last_active_after| {
163                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
164                    .gt(last_active_after)
165            }))
166            .add_option(self.last_active_before().map(|last_active_before| {
167                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
168                    .lt(last_active_before)
169            }))
170    }
171}
172
173#[async_trait]
174impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
175    type Error = DatabaseError;
176
177    #[tracing::instrument(
178        name = "db.oauth2_session.lookup",
179        skip_all,
180        fields(
181            db.query.text,
182            session.id = %id,
183        ),
184        err,
185    )]
186    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
187        let res = sqlx::query_as!(
188            OAuthSessionLookup,
189            r#"
190                SELECT oauth2_session_id
191                     , user_id
192                     , user_session_id
193                     , oauth2_client_id
194                     , scope_list
195                     , created_at
196                     , finished_at
197                     , user_agent
198                     , last_active_at
199                     , last_active_ip as "last_active_ip: IpAddr"
200                     , human_name
201                FROM oauth2_sessions
202
203                WHERE oauth2_session_id = $1
204            "#,
205            Uuid::from(id),
206        )
207        .traced()
208        .fetch_optional(&mut *self.conn)
209        .await?;
210
211        let Some(session) = res else { return Ok(None) };
212
213        Ok(Some(session.try_into()?))
214    }
215
216    #[tracing::instrument(
217        name = "db.oauth2_session.add",
218        skip_all,
219        fields(
220            db.query.text,
221            %client.id,
222            session.id,
223            session.scope = %scope,
224        ),
225        err,
226    )]
227    async fn add(
228        &mut self,
229        rng: &mut (dyn RngCore + Send),
230        clock: &dyn Clock,
231        client: &Client,
232        user: Option<&User>,
233        user_session: Option<&BrowserSession>,
234        scope: Scope,
235    ) -> Result<Session, Self::Error> {
236        let created_at = clock.now();
237        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
238        tracing::Span::current().record("session.id", tracing::field::display(id));
239
240        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
241
242        sqlx::query!(
243            r#"
244                INSERT INTO oauth2_sessions
245                    ( oauth2_session_id
246                    , user_id
247                    , user_session_id
248                    , oauth2_client_id
249                    , scope_list
250                    , created_at
251                    )
252                VALUES ($1, $2, $3, $4, $5, $6)
253            "#,
254            Uuid::from(id),
255            user.map(|u| Uuid::from(u.id)),
256            user_session.map(|s| Uuid::from(s.id)),
257            Uuid::from(client.id),
258            &scope_list,
259            created_at,
260        )
261        .traced()
262        .execute(&mut *self.conn)
263        .await?;
264
265        Ok(Session {
266            id,
267            state: SessionState::Valid,
268            created_at,
269            user_id: user.map(|u| u.id),
270            user_session_id: user_session.map(|s| s.id),
271            client_id: client.id,
272            scope,
273            user_agent: None,
274            last_active_at: None,
275            last_active_ip: None,
276            human_name: None,
277        })
278    }
279
280    #[tracing::instrument(
281        name = "db.oauth2_session.finish_bulk",
282        skip_all,
283        fields(
284            db.query.text,
285        ),
286        err,
287    )]
288    async fn finish_bulk(
289        &mut self,
290        clock: &dyn Clock,
291        filter: OAuth2SessionFilter<'_>,
292    ) -> Result<usize, Self::Error> {
293        let finished_at = clock.now();
294        let (sql, arguments) = Query::update()
295            .table(OAuth2Sessions::Table)
296            .value(OAuth2Sessions::FinishedAt, finished_at)
297            .apply_filter(filter)
298            .build_sqlx(PostgresQueryBuilder);
299
300        let res = sqlx::query_with(&sql, arguments)
301            .traced()
302            .execute(&mut *self.conn)
303            .await?;
304
305        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
306    }
307
308    #[tracing::instrument(
309        name = "db.oauth2_session.finish",
310        skip_all,
311        fields(
312            db.query.text,
313            %session.id,
314            %session.scope,
315            client.id = %session.client_id,
316        ),
317        err,
318    )]
319    async fn finish(
320        &mut self,
321        clock: &dyn Clock,
322        session: Session,
323    ) -> Result<Session, Self::Error> {
324        let finished_at = clock.now();
325        let res = sqlx::query!(
326            r#"
327                UPDATE oauth2_sessions
328                SET finished_at = $2
329                WHERE oauth2_session_id = $1
330            "#,
331            Uuid::from(session.id),
332            finished_at,
333        )
334        .traced()
335        .execute(&mut *self.conn)
336        .await?;
337
338        DatabaseError::ensure_affected_rows(&res, 1)?;
339
340        session
341            .finish(finished_at)
342            .map_err(DatabaseError::to_invalid_operation)
343    }
344
345    #[tracing::instrument(
346        name = "db.oauth2_session.list",
347        skip_all,
348        fields(
349            db.query.text,
350        ),
351        err,
352    )]
353    async fn list(
354        &mut self,
355        filter: OAuth2SessionFilter<'_>,
356        pagination: Pagination,
357    ) -> Result<Page<Session>, Self::Error> {
358        let (sql, arguments) = Query::select()
359            .expr_as(
360                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
361                OAuthSessionLookupIden::Oauth2SessionId,
362            )
363            .expr_as(
364                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
365                OAuthSessionLookupIden::UserId,
366            )
367            .expr_as(
368                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
369                OAuthSessionLookupIden::UserSessionId,
370            )
371            .expr_as(
372                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
373                OAuthSessionLookupIden::Oauth2ClientId,
374            )
375            .expr_as(
376                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
377                OAuthSessionLookupIden::ScopeList,
378            )
379            .expr_as(
380                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
381                OAuthSessionLookupIden::CreatedAt,
382            )
383            .expr_as(
384                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
385                OAuthSessionLookupIden::FinishedAt,
386            )
387            .expr_as(
388                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
389                OAuthSessionLookupIden::UserAgent,
390            )
391            .expr_as(
392                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
393                OAuthSessionLookupIden::LastActiveAt,
394            )
395            .expr_as(
396                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
397                OAuthSessionLookupIden::LastActiveIp,
398            )
399            .expr_as(
400                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
401                OAuthSessionLookupIden::HumanName,
402            )
403            .from(OAuth2Sessions::Table)
404            .apply_filter(filter)
405            .generate_pagination(
406                (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
407                pagination,
408            )
409            .build_sqlx(PostgresQueryBuilder);
410
411        let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
412            .traced()
413            .fetch_all(&mut *self.conn)
414            .await?;
415
416        let page = pagination.process(edges).try_map(Session::try_from)?;
417
418        Ok(page)
419    }
420
421    #[tracing::instrument(
422        name = "db.oauth2_session.count",
423        skip_all,
424        fields(
425            db.query.text,
426        ),
427        err,
428    )]
429    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
430        let (sql, arguments) = Query::select()
431            .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
432            .from(OAuth2Sessions::Table)
433            .apply_filter(filter)
434            .build_sqlx(PostgresQueryBuilder);
435
436        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
437            .traced()
438            .fetch_one(&mut *self.conn)
439            .await?;
440
441        count
442            .try_into()
443            .map_err(DatabaseError::to_invalid_operation)
444    }
445
446    #[tracing::instrument(
447        name = "db.oauth2_session.record_batch_activity",
448        skip_all,
449        fields(
450            db.query.text,
451        ),
452        err,
453    )]
454    async fn record_batch_activity(
455        &mut self,
456        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
457    ) -> Result<(), Self::Error> {
458        let mut ids = Vec::with_capacity(activity.len());
459        let mut last_activities = Vec::with_capacity(activity.len());
460        let mut ips = Vec::with_capacity(activity.len());
461
462        for (id, last_activity, ip) in activity {
463            ids.push(Uuid::from(id));
464            last_activities.push(last_activity);
465            ips.push(ip);
466        }
467
468        let res = sqlx::query!(
469            r#"
470                UPDATE oauth2_sessions
471                SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
472                  , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
473                FROM (
474                    SELECT *
475                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
476                        AS t(oauth2_session_id, last_active_at, last_active_ip)
477                ) AS t
478                WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
479            "#,
480            &ids,
481            &last_activities,
482            &ips as &[Option<IpAddr>],
483        )
484        .traced()
485        .execute(&mut *self.conn)
486        .await?;
487
488        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
489
490        Ok(())
491    }
492
493    #[tracing::instrument(
494        name = "db.oauth2_session.record_user_agent",
495        skip_all,
496        fields(
497            db.query.text,
498            %session.id,
499            %session.scope,
500            client.id = %session.client_id,
501            session.user_agent = user_agent,
502        ),
503        err,
504    )]
505    async fn record_user_agent(
506        &mut self,
507        mut session: Session,
508        user_agent: String,
509    ) -> Result<Session, Self::Error> {
510        let res = sqlx::query!(
511            r#"
512                UPDATE oauth2_sessions
513                SET user_agent = $2
514                WHERE oauth2_session_id = $1
515            "#,
516            Uuid::from(session.id),
517            &*user_agent,
518        )
519        .traced()
520        .execute(&mut *self.conn)
521        .await?;
522
523        session.user_agent = Some(user_agent);
524
525        DatabaseError::ensure_affected_rows(&res, 1)?;
526
527        Ok(session)
528    }
529
530    #[tracing::instrument(
531        name = "repository.oauth2_session.set_human_name",
532        skip(self),
533        fields(
534            client.id = %session.client_id,
535            session.human_name = ?human_name,
536        ),
537        err,
538    )]
539    async fn set_human_name(
540        &mut self,
541        mut session: Session,
542        human_name: Option<String>,
543    ) -> Result<Session, Self::Error> {
544        let res = sqlx::query!(
545            r#"
546                UPDATE oauth2_sessions
547                SET human_name = $2
548                WHERE oauth2_session_id = $1
549            "#,
550            Uuid::from(session.id),
551            human_name.as_deref(),
552        )
553        .traced()
554        .execute(&mut *self.conn)
555        .await?;
556
557        session.human_name = human_name;
558
559        DatabaseError::ensure_affected_rows(&res, 1)?;
560
561        Ok(session)
562    }
563}