mas_storage_pg/compat/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
use std::net::IpAddr;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{
    BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
    User,
};
use mas_storage::{
    Clock, Page, Pagination,
    compat::{CompatSessionFilter, CompatSessionRepository},
};
use rand::RngCore;
use sea_query::{ConditionalStatement, Expr, PostgresQueryBuilder, Query, enum_def};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;

use crate::{
    DatabaseError, DatabaseInconsistencyError,
    filter::{Filter, StatementExt, StatementWithJoinsExt},
    iden::{CompatSessions, CompatSsoLogins},
    pagination::QueryBuilderExt,
    tracing::ExecuteExt,
};
34
35/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
36pub struct PgCompatSessionRepository<'c> {
37    conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41    /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
48struct CompatSessionLookup {
49    compat_session_id: Uuid,
50    device_id: Option<String>,
51    human_name: Option<String>,
52    user_id: Uuid,
53    user_session_id: Option<Uuid>,
54    created_at: DateTime<Utc>,
55    finished_at: Option<DateTime<Utc>>,
56    is_synapse_admin: bool,
57    user_agent: Option<String>,
58    last_active_at: Option<DateTime<Utc>>,
59    last_active_ip: Option<IpAddr>,
60}
61
62impl From<CompatSessionLookup> for CompatSession {
63    fn from(value: CompatSessionLookup) -> Self {
64        let id = value.compat_session_id.into();
65
66        let state = match value.finished_at {
67            None => CompatSessionState::Valid,
68            Some(finished_at) => CompatSessionState::Finished { finished_at },
69        };
70
71        CompatSession {
72            id,
73            state,
74            user_id: value.user_id.into(),
75            user_session_id: value.user_session_id.map(Ulid::from),
76            device: value.device_id.map(Device::from),
77            human_name: value.human_name,
78            created_at: value.created_at,
79            is_synapse_admin: value.is_synapse_admin,
80            user_agent: value.user_agent,
81            last_active_at: value.last_active_at,
82            last_active_ip: value.last_active_ip,
83        }
84    }
85}
86
87#[derive(sqlx::FromRow)]
88#[enum_def]
89struct CompatSessionAndSsoLoginLookup {
90    compat_session_id: Uuid,
91    device_id: Option<String>,
92    human_name: Option<String>,
93    user_id: Uuid,
94    user_session_id: Option<Uuid>,
95    created_at: DateTime<Utc>,
96    finished_at: Option<DateTime<Utc>>,
97    is_synapse_admin: bool,
98    user_agent: Option<String>,
99    last_active_at: Option<DateTime<Utc>>,
100    last_active_ip: Option<IpAddr>,
101    compat_sso_login_id: Option<Uuid>,
102    compat_sso_login_token: Option<String>,
103    compat_sso_login_redirect_uri: Option<String>,
104    compat_sso_login_created_at: Option<DateTime<Utc>>,
105    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
106    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
107}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110    type Error = DatabaseInconsistencyError;
111
112    fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113        let id = value.compat_session_id.into();
114
115        let state = match value.finished_at {
116            None => CompatSessionState::Valid,
117            Some(finished_at) => CompatSessionState::Finished { finished_at },
118        };
119
120        let session = CompatSession {
121            id,
122            state,
123            user_id: value.user_id.into(),
124            device: value.device_id.map(Device::from),
125            human_name: value.human_name,
126            user_session_id: value.user_session_id.map(Ulid::from),
127            created_at: value.created_at,
128            is_synapse_admin: value.is_synapse_admin,
129            user_agent: value.user_agent,
130            last_active_at: value.last_active_at,
131            last_active_ip: value.last_active_ip,
132        };
133
134        match (
135            value.compat_sso_login_id,
136            value.compat_sso_login_token,
137            value.compat_sso_login_redirect_uri,
138            value.compat_sso_login_created_at,
139            value.compat_sso_login_fulfilled_at,
140            value.compat_sso_login_exchanged_at,
141        ) {
142            (None, None, None, None, None, None) => Ok((session, None)),
143            (
144                Some(id),
145                Some(login_token),
146                Some(redirect_uri),
147                Some(created_at),
148                fulfilled_at,
149                exchanged_at,
150            ) => {
151                let id = id.into();
152                let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153                    DatabaseInconsistencyError::on("compat_sso_logins")
154                        .column("redirect_uri")
155                        .row(id)
156                        .source(e)
157                })?;
158
159                let state = match (fulfilled_at, exchanged_at) {
160                    (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
161                        fulfilled_at,
162                        exchanged_at,
163                        compat_session_id: session.id,
164                    },
165                    _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
166                };
167
168                let login = CompatSsoLogin {
169                    id,
170                    redirect_uri,
171                    login_token,
172                    created_at,
173                    state,
174                };
175
176                Ok((session, Some(login)))
177            }
178            _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179        }
180    }
181}
182
183impl Filter for CompatSessionFilter<'_> {
184    fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
185        sea_query::Condition::all()
186            .add_option(self.user().map(|user| {
187                Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
188            }))
189            .add_option(self.browser_session().map(|browser_session| {
190                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
191                    .eq(Uuid::from(browser_session.id))
192            }))
193            .add_option(self.state().map(|state| {
194                if state.is_active() {
195                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
196                } else {
197                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
198                }
199            }))
200            .add_option(self.auth_type().map(|auth_type| {
201                // In in the SELECT to list sessions, we can rely on the JOINed table, whereas
202                // in other queries we need to do a subquery
203                if has_joins {
204                    if auth_type.is_sso_login() {
205                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
206                            .is_not_null()
207                    } else {
208                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
209                            .is_null()
210                    }
211                } else {
212                    // This builds either a:
213                    // `WHERE compat_session_id = ANY(...)`
214                    // or a `WHERE compat_session_id <> ALL(...)`
215                    let compat_sso_logins = Query::select()
216                        .expr(Expr::col((
217                            CompatSsoLogins::Table,
218                            CompatSsoLogins::CompatSessionId,
219                        )))
220                        .from(CompatSsoLogins::Table)
221                        .take();
222
223                    if auth_type.is_sso_login() {
224                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
225                            .eq(Expr::any(compat_sso_logins))
226                    } else {
227                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
228                            .ne(Expr::all(compat_sso_logins))
229                    }
230                }
231            }))
232            .add_option(self.last_active_after().map(|last_active_after| {
233                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
234                    .gt(last_active_after)
235            }))
236            .add_option(self.last_active_before().map(|last_active_before| {
237                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
238                    .lt(last_active_before)
239            }))
240            .add_option(self.device().map(|device| {
241                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
242            }))
243    }
244}
245
246#[async_trait]
247impl CompatSessionRepository for PgCompatSessionRepository<'_> {
248    type Error = DatabaseError;
249
250    #[tracing::instrument(
251        name = "db.compat_session.lookup",
252        skip_all,
253        fields(
254            db.query.text,
255            compat_session.id = %id,
256        ),
257        err,
258    )]
259    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
260        let res = sqlx::query_as!(
261            CompatSessionLookup,
262            r#"
263                SELECT compat_session_id
264                     , device_id
265                     , human_name
266                     , user_id
267                     , user_session_id
268                     , created_at
269                     , finished_at
270                     , is_synapse_admin
271                     , user_agent
272                     , last_active_at
273                     , last_active_ip as "last_active_ip: IpAddr"
274                FROM compat_sessions
275                WHERE compat_session_id = $1
276            "#,
277            Uuid::from(id),
278        )
279        .traced()
280        .fetch_optional(&mut *self.conn)
281        .await?;
282
283        let Some(res) = res else { return Ok(None) };
284
285        Ok(Some(res.into()))
286    }
287
288    #[tracing::instrument(
289        name = "db.compat_session.add",
290        skip_all,
291        fields(
292            db.query.text,
293            compat_session.id,
294            %user.id,
295            %user.username,
296            compat_session.device.id = device.as_str(),
297        ),
298        err,
299    )]
300    async fn add(
301        &mut self,
302        rng: &mut (dyn RngCore + Send),
303        clock: &dyn Clock,
304        user: &User,
305        device: Device,
306        browser_session: Option<&BrowserSession>,
307        is_synapse_admin: bool,
308        human_name: Option<String>,
309    ) -> Result<CompatSession, Self::Error> {
310        let created_at = clock.now();
311        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
312        tracing::Span::current().record("compat_session.id", tracing::field::display(id));
313
314        sqlx::query!(
315            r#"
316                INSERT INTO compat_sessions
317                    (compat_session_id, user_id, device_id,
318                     user_session_id, created_at, is_synapse_admin,
319                     human_name)
320                VALUES ($1, $2, $3, $4, $5, $6, $7)
321            "#,
322            Uuid::from(id),
323            Uuid::from(user.id),
324            device.as_str(),
325            browser_session.map(|s| Uuid::from(s.id)),
326            created_at,
327            is_synapse_admin,
328            human_name.as_deref(),
329        )
330        .traced()
331        .execute(&mut *self.conn)
332        .await?;
333
334        Ok(CompatSession {
335            id,
336            state: CompatSessionState::default(),
337            user_id: user.id,
338            device: Some(device),
339            human_name,
340            user_session_id: browser_session.map(|s| s.id),
341            created_at,
342            is_synapse_admin,
343            user_agent: None,
344            last_active_at: None,
345            last_active_ip: None,
346        })
347    }
348
349    #[tracing::instrument(
350        name = "db.compat_session.finish",
351        skip_all,
352        fields(
353            db.query.text,
354            %compat_session.id,
355            user.id = %compat_session.user_id,
356            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
357        ),
358        err,
359    )]
360    async fn finish(
361        &mut self,
362        clock: &dyn Clock,
363        compat_session: CompatSession,
364    ) -> Result<CompatSession, Self::Error> {
365        let finished_at = clock.now();
366
367        let res = sqlx::query!(
368            r#"
369                UPDATE compat_sessions cs
370                SET finished_at = $2
371                WHERE compat_session_id = $1
372            "#,
373            Uuid::from(compat_session.id),
374            finished_at,
375        )
376        .traced()
377        .execute(&mut *self.conn)
378        .await?;
379
380        DatabaseError::ensure_affected_rows(&res, 1)?;
381
382        let compat_session = compat_session
383            .finish(finished_at)
384            .map_err(DatabaseError::to_invalid_operation)?;
385
386        Ok(compat_session)
387    }
388
389    #[tracing::instrument(
390        name = "db.compat_session.finish_bulk",
391        skip_all,
392        fields(db.query.text),
393        err,
394    )]
395    async fn finish_bulk(
396        &mut self,
397        clock: &dyn Clock,
398        filter: CompatSessionFilter<'_>,
399    ) -> Result<usize, Self::Error> {
400        let finished_at = clock.now();
401        let (sql, arguments) = Query::update()
402            .table(CompatSessions::Table)
403            .value(CompatSessions::FinishedAt, finished_at)
404            .apply_filter(filter)
405            .build_sqlx(PostgresQueryBuilder);
406
407        let res = sqlx::query_with(&sql, arguments)
408            .traced()
409            .execute(&mut *self.conn)
410            .await?;
411
412        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
413    }
414
415    #[tracing::instrument(
416        name = "db.compat_session.list",
417        skip_all,
418        fields(
419            db.query.text,
420        ),
421        err,
422    )]
423    async fn list(
424        &mut self,
425        filter: CompatSessionFilter<'_>,
426        pagination: Pagination,
427    ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
428        let (sql, arguments) = Query::select()
429            .expr_as(
430                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
431                CompatSessionAndSsoLoginLookupIden::CompatSessionId,
432            )
433            .expr_as(
434                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
435                CompatSessionAndSsoLoginLookupIden::DeviceId,
436            )
437            .expr_as(
438                Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
439                CompatSessionAndSsoLoginLookupIden::HumanName,
440            )
441            .expr_as(
442                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
443                CompatSessionAndSsoLoginLookupIden::UserId,
444            )
445            .expr_as(
446                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
447                CompatSessionAndSsoLoginLookupIden::UserSessionId,
448            )
449            .expr_as(
450                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
451                CompatSessionAndSsoLoginLookupIden::CreatedAt,
452            )
453            .expr_as(
454                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
455                CompatSessionAndSsoLoginLookupIden::FinishedAt,
456            )
457            .expr_as(
458                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
459                CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
460            )
461            .expr_as(
462                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
463                CompatSessionAndSsoLoginLookupIden::UserAgent,
464            )
465            .expr_as(
466                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
467                CompatSessionAndSsoLoginLookupIden::LastActiveAt,
468            )
469            .expr_as(
470                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
471                CompatSessionAndSsoLoginLookupIden::LastActiveIp,
472            )
473            .expr_as(
474                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
475                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
476            )
477            .expr_as(
478                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
479                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
480            )
481            .expr_as(
482                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
483                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
484            )
485            .expr_as(
486                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
487                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
488            )
489            .expr_as(
490                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
491                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
492            )
493            .expr_as(
494                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
495                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
496            )
497            .from(CompatSessions::Table)
498            .left_join(
499                CompatSsoLogins::Table,
500                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
501                    .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
502            )
503            .apply_filter_with_joins(filter)
504            .generate_pagination(
505                (CompatSessions::Table, CompatSessions::CompatSessionId),
506                pagination,
507            )
508            .build_sqlx(PostgresQueryBuilder);
509
510        let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
511            .traced()
512            .fetch_all(&mut *self.conn)
513            .await?;
514
515        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
516
517        Ok(page)
518    }
519
520    #[tracing::instrument(
521        name = "db.compat_session.count",
522        skip_all,
523        fields(
524            db.query.text,
525        ),
526        err,
527    )]
528    async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
529        let (sql, arguments) = sea_query::Query::select()
530            .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
531            .from(CompatSessions::Table)
532            .apply_filter(filter)
533            .build_sqlx(PostgresQueryBuilder);
534
535        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
536            .traced()
537            .fetch_one(&mut *self.conn)
538            .await?;
539
540        count
541            .try_into()
542            .map_err(DatabaseError::to_invalid_operation)
543    }
544
545    #[tracing::instrument(
546        name = "db.compat_session.record_batch_activity",
547        skip_all,
548        fields(
549            db.query.text,
550        ),
551        err,
552    )]
553    async fn record_batch_activity(
554        &mut self,
555        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
556    ) -> Result<(), Self::Error> {
557        let mut ids = Vec::with_capacity(activity.len());
558        let mut last_activities = Vec::with_capacity(activity.len());
559        let mut ips = Vec::with_capacity(activity.len());
560
561        for (id, last_activity, ip) in activity {
562            ids.push(Uuid::from(id));
563            last_activities.push(last_activity);
564            ips.push(ip);
565        }
566
567        let res = sqlx::query!(
568            r#"
569                UPDATE compat_sessions
570                SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
571                  , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
572                FROM (
573                    SELECT *
574                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
575                        AS t(compat_session_id, last_active_at, last_active_ip)
576                ) AS t
577                WHERE compat_sessions.compat_session_id = t.compat_session_id
578            "#,
579            &ids,
580            &last_activities,
581            &ips as &[Option<IpAddr>],
582        )
583        .traced()
584        .execute(&mut *self.conn)
585        .await?;
586
587        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
588
589        Ok(())
590    }
591
592    #[tracing::instrument(
593        name = "db.compat_session.record_user_agent",
594        skip_all,
595        fields(
596            db.query.text,
597            %compat_session.id,
598        ),
599        err,
600    )]
601    async fn record_user_agent(
602        &mut self,
603        mut compat_session: CompatSession,
604        user_agent: String,
605    ) -> Result<CompatSession, Self::Error> {
606        let res = sqlx::query!(
607            r#"
608            UPDATE compat_sessions
609            SET user_agent = $2
610            WHERE compat_session_id = $1
611        "#,
612            Uuid::from(compat_session.id),
613            &*user_agent,
614        )
615        .traced()
616        .execute(&mut *self.conn)
617        .await?;
618
619        compat_session.user_agent = Some(user_agent);
620
621        DatabaseError::ensure_affected_rows(&res, 1)?;
622
623        Ok(compat_session)
624    }
625
626    #[tracing::instrument(
627        name = "repository.compat_session.set_human_name",
628        skip(self),
629        fields(
630            compat_session.id = %compat_session.id,
631            compat_session.human_name = ?human_name,
632        ),
633        err,
634    )]
635    async fn set_human_name(
636        &mut self,
637        mut compat_session: CompatSession,
638        human_name: Option<String>,
639    ) -> Result<CompatSession, Self::Error> {
640        let res = sqlx::query!(
641            r#"
642            UPDATE compat_sessions
643            SET human_name = $2
644            WHERE compat_session_id = $1
645        "#,
646            Uuid::from(compat_session.id),
647            human_name.as_deref(),
648        )
649        .traced()
650        .execute(&mut *self.conn)
651        .await?;
652
653        compat_session.human_name = human_name;
654
655        DatabaseError::ensure_affected_rows(&res, 1)?;
656
657        Ok(compat_session)
658    }
659}