1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13 User,
14};
15use mas_storage::{
16 Clock, Page, Pagination,
17 compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::{Filter, StatementExt, StatementWithJoinsExt},
30 iden::{CompatSessions, CompatSsoLogins},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
35pub struct PgCompatSessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
48struct CompatSessionLookup {
49 compat_session_id: Uuid,
50 device_id: Option<String>,
51 human_name: Option<String>,
52 user_id: Uuid,
53 user_session_id: Option<Uuid>,
54 created_at: DateTime<Utc>,
55 finished_at: Option<DateTime<Utc>>,
56 is_synapse_admin: bool,
57 user_agent: Option<String>,
58 last_active_at: Option<DateTime<Utc>>,
59 last_active_ip: Option<IpAddr>,
60}
61
62impl From<CompatSessionLookup> for CompatSession {
63 fn from(value: CompatSessionLookup) -> Self {
64 let id = value.compat_session_id.into();
65
66 let state = match value.finished_at {
67 None => CompatSessionState::Valid,
68 Some(finished_at) => CompatSessionState::Finished { finished_at },
69 };
70
71 CompatSession {
72 id,
73 state,
74 user_id: value.user_id.into(),
75 user_session_id: value.user_session_id.map(Ulid::from),
76 device: value.device_id.map(Device::from),
77 human_name: value.human_name,
78 created_at: value.created_at,
79 is_synapse_admin: value.is_synapse_admin,
80 user_agent: value.user_agent,
81 last_active_at: value.last_active_at,
82 last_active_ip: value.last_active_ip,
83 }
84 }
85}
86
87#[derive(sqlx::FromRow)]
88#[enum_def]
89struct CompatSessionAndSsoLoginLookup {
90 compat_session_id: Uuid,
91 device_id: Option<String>,
92 human_name: Option<String>,
93 user_id: Uuid,
94 user_session_id: Option<Uuid>,
95 created_at: DateTime<Utc>,
96 finished_at: Option<DateTime<Utc>>,
97 is_synapse_admin: bool,
98 user_agent: Option<String>,
99 last_active_at: Option<DateTime<Utc>>,
100 last_active_ip: Option<IpAddr>,
101 compat_sso_login_id: Option<Uuid>,
102 compat_sso_login_token: Option<String>,
103 compat_sso_login_redirect_uri: Option<String>,
104 compat_sso_login_created_at: Option<DateTime<Utc>>,
105 compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
106 compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
107}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110 type Error = DatabaseInconsistencyError;
111
112 fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113 let id = value.compat_session_id.into();
114
115 let state = match value.finished_at {
116 None => CompatSessionState::Valid,
117 Some(finished_at) => CompatSessionState::Finished { finished_at },
118 };
119
120 let session = CompatSession {
121 id,
122 state,
123 user_id: value.user_id.into(),
124 device: value.device_id.map(Device::from),
125 human_name: value.human_name,
126 user_session_id: value.user_session_id.map(Ulid::from),
127 created_at: value.created_at,
128 is_synapse_admin: value.is_synapse_admin,
129 user_agent: value.user_agent,
130 last_active_at: value.last_active_at,
131 last_active_ip: value.last_active_ip,
132 };
133
134 match (
135 value.compat_sso_login_id,
136 value.compat_sso_login_token,
137 value.compat_sso_login_redirect_uri,
138 value.compat_sso_login_created_at,
139 value.compat_sso_login_fulfilled_at,
140 value.compat_sso_login_exchanged_at,
141 ) {
142 (None, None, None, None, None, None) => Ok((session, None)),
143 (
144 Some(id),
145 Some(login_token),
146 Some(redirect_uri),
147 Some(created_at),
148 fulfilled_at,
149 exchanged_at,
150 ) => {
151 let id = id.into();
152 let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153 DatabaseInconsistencyError::on("compat_sso_logins")
154 .column("redirect_uri")
155 .row(id)
156 .source(e)
157 })?;
158
159 let state = match (fulfilled_at, exchanged_at) {
160 (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
161 fulfilled_at,
162 exchanged_at,
163 compat_session_id: session.id,
164 },
165 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
166 };
167
168 let login = CompatSsoLogin {
169 id,
170 redirect_uri,
171 login_token,
172 created_at,
173 state,
174 };
175
176 Ok((session, Some(login)))
177 }
178 _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179 }
180 }
181}
182
183impl Filter for CompatSessionFilter<'_> {
184 fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
185 sea_query::Condition::all()
186 .add_option(self.user().map(|user| {
187 Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
188 }))
189 .add_option(self.browser_session().map(|browser_session| {
190 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
191 .eq(Uuid::from(browser_session.id))
192 }))
193 .add_option(self.state().map(|state| {
194 if state.is_active() {
195 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
196 } else {
197 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
198 }
199 }))
200 .add_option(self.auth_type().map(|auth_type| {
201 if has_joins {
204 if auth_type.is_sso_login() {
205 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
206 .is_not_null()
207 } else {
208 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
209 .is_null()
210 }
211 } else {
212 let compat_sso_logins = Query::select()
216 .expr(Expr::col((
217 CompatSsoLogins::Table,
218 CompatSsoLogins::CompatSessionId,
219 )))
220 .from(CompatSsoLogins::Table)
221 .take();
222
223 if auth_type.is_sso_login() {
224 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
225 .eq(Expr::any(compat_sso_logins))
226 } else {
227 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
228 .ne(Expr::all(compat_sso_logins))
229 }
230 }
231 }))
232 .add_option(self.last_active_after().map(|last_active_after| {
233 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
234 .gt(last_active_after)
235 }))
236 .add_option(self.last_active_before().map(|last_active_before| {
237 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
238 .lt(last_active_before)
239 }))
240 .add_option(self.device().map(|device| {
241 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
242 }))
243 }
244}
245
246#[async_trait]
247impl CompatSessionRepository for PgCompatSessionRepository<'_> {
248 type Error = DatabaseError;
249
250 #[tracing::instrument(
251 name = "db.compat_session.lookup",
252 skip_all,
253 fields(
254 db.query.text,
255 compat_session.id = %id,
256 ),
257 err,
258 )]
259 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
260 let res = sqlx::query_as!(
261 CompatSessionLookup,
262 r#"
263 SELECT compat_session_id
264 , device_id
265 , human_name
266 , user_id
267 , user_session_id
268 , created_at
269 , finished_at
270 , is_synapse_admin
271 , user_agent
272 , last_active_at
273 , last_active_ip as "last_active_ip: IpAddr"
274 FROM compat_sessions
275 WHERE compat_session_id = $1
276 "#,
277 Uuid::from(id),
278 )
279 .traced()
280 .fetch_optional(&mut *self.conn)
281 .await?;
282
283 let Some(res) = res else { return Ok(None) };
284
285 Ok(Some(res.into()))
286 }
287
288 #[tracing::instrument(
289 name = "db.compat_session.add",
290 skip_all,
291 fields(
292 db.query.text,
293 compat_session.id,
294 %user.id,
295 %user.username,
296 compat_session.device.id = device.as_str(),
297 ),
298 err,
299 )]
300 async fn add(
301 &mut self,
302 rng: &mut (dyn RngCore + Send),
303 clock: &dyn Clock,
304 user: &User,
305 device: Device,
306 browser_session: Option<&BrowserSession>,
307 is_synapse_admin: bool,
308 human_name: Option<String>,
309 ) -> Result<CompatSession, Self::Error> {
310 let created_at = clock.now();
311 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
312 tracing::Span::current().record("compat_session.id", tracing::field::display(id));
313
314 sqlx::query!(
315 r#"
316 INSERT INTO compat_sessions
317 (compat_session_id, user_id, device_id,
318 user_session_id, created_at, is_synapse_admin,
319 human_name)
320 VALUES ($1, $2, $3, $4, $5, $6, $7)
321 "#,
322 Uuid::from(id),
323 Uuid::from(user.id),
324 device.as_str(),
325 browser_session.map(|s| Uuid::from(s.id)),
326 created_at,
327 is_synapse_admin,
328 human_name.as_deref(),
329 )
330 .traced()
331 .execute(&mut *self.conn)
332 .await?;
333
334 Ok(CompatSession {
335 id,
336 state: CompatSessionState::default(),
337 user_id: user.id,
338 device: Some(device),
339 human_name,
340 user_session_id: browser_session.map(|s| s.id),
341 created_at,
342 is_synapse_admin,
343 user_agent: None,
344 last_active_at: None,
345 last_active_ip: None,
346 })
347 }
348
349 #[tracing::instrument(
350 name = "db.compat_session.finish",
351 skip_all,
352 fields(
353 db.query.text,
354 %compat_session.id,
355 user.id = %compat_session.user_id,
356 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
357 ),
358 err,
359 )]
360 async fn finish(
361 &mut self,
362 clock: &dyn Clock,
363 compat_session: CompatSession,
364 ) -> Result<CompatSession, Self::Error> {
365 let finished_at = clock.now();
366
367 let res = sqlx::query!(
368 r#"
369 UPDATE compat_sessions cs
370 SET finished_at = $2
371 WHERE compat_session_id = $1
372 "#,
373 Uuid::from(compat_session.id),
374 finished_at,
375 )
376 .traced()
377 .execute(&mut *self.conn)
378 .await?;
379
380 DatabaseError::ensure_affected_rows(&res, 1)?;
381
382 let compat_session = compat_session
383 .finish(finished_at)
384 .map_err(DatabaseError::to_invalid_operation)?;
385
386 Ok(compat_session)
387 }
388
389 #[tracing::instrument(
390 name = "db.compat_session.finish_bulk",
391 skip_all,
392 fields(db.query.text),
393 err,
394 )]
395 async fn finish_bulk(
396 &mut self,
397 clock: &dyn Clock,
398 filter: CompatSessionFilter<'_>,
399 ) -> Result<usize, Self::Error> {
400 let finished_at = clock.now();
401 let (sql, arguments) = Query::update()
402 .table(CompatSessions::Table)
403 .value(CompatSessions::FinishedAt, finished_at)
404 .apply_filter(filter)
405 .build_sqlx(PostgresQueryBuilder);
406
407 let res = sqlx::query_with(&sql, arguments)
408 .traced()
409 .execute(&mut *self.conn)
410 .await?;
411
412 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
413 }
414
415 #[tracing::instrument(
416 name = "db.compat_session.list",
417 skip_all,
418 fields(
419 db.query.text,
420 ),
421 err,
422 )]
423 async fn list(
424 &mut self,
425 filter: CompatSessionFilter<'_>,
426 pagination: Pagination,
427 ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
428 let (sql, arguments) = Query::select()
429 .expr_as(
430 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
431 CompatSessionAndSsoLoginLookupIden::CompatSessionId,
432 )
433 .expr_as(
434 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
435 CompatSessionAndSsoLoginLookupIden::DeviceId,
436 )
437 .expr_as(
438 Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
439 CompatSessionAndSsoLoginLookupIden::HumanName,
440 )
441 .expr_as(
442 Expr::col((CompatSessions::Table, CompatSessions::UserId)),
443 CompatSessionAndSsoLoginLookupIden::UserId,
444 )
445 .expr_as(
446 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
447 CompatSessionAndSsoLoginLookupIden::UserSessionId,
448 )
449 .expr_as(
450 Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
451 CompatSessionAndSsoLoginLookupIden::CreatedAt,
452 )
453 .expr_as(
454 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
455 CompatSessionAndSsoLoginLookupIden::FinishedAt,
456 )
457 .expr_as(
458 Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
459 CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
460 )
461 .expr_as(
462 Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
463 CompatSessionAndSsoLoginLookupIden::UserAgent,
464 )
465 .expr_as(
466 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
467 CompatSessionAndSsoLoginLookupIden::LastActiveAt,
468 )
469 .expr_as(
470 Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
471 CompatSessionAndSsoLoginLookupIden::LastActiveIp,
472 )
473 .expr_as(
474 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
475 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
476 )
477 .expr_as(
478 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
479 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
480 )
481 .expr_as(
482 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
483 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
484 )
485 .expr_as(
486 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
487 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
488 )
489 .expr_as(
490 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
491 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
492 )
493 .expr_as(
494 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
495 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
496 )
497 .from(CompatSessions::Table)
498 .left_join(
499 CompatSsoLogins::Table,
500 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
501 .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
502 )
503 .apply_filter_with_joins(filter)
504 .generate_pagination(
505 (CompatSessions::Table, CompatSessions::CompatSessionId),
506 pagination,
507 )
508 .build_sqlx(PostgresQueryBuilder);
509
510 let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
511 .traced()
512 .fetch_all(&mut *self.conn)
513 .await?;
514
515 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
516
517 Ok(page)
518 }
519
520 #[tracing::instrument(
521 name = "db.compat_session.count",
522 skip_all,
523 fields(
524 db.query.text,
525 ),
526 err,
527 )]
528 async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
529 let (sql, arguments) = sea_query::Query::select()
530 .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
531 .from(CompatSessions::Table)
532 .apply_filter(filter)
533 .build_sqlx(PostgresQueryBuilder);
534
535 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
536 .traced()
537 .fetch_one(&mut *self.conn)
538 .await?;
539
540 count
541 .try_into()
542 .map_err(DatabaseError::to_invalid_operation)
543 }
544
545 #[tracing::instrument(
546 name = "db.compat_session.record_batch_activity",
547 skip_all,
548 fields(
549 db.query.text,
550 ),
551 err,
552 )]
553 async fn record_batch_activity(
554 &mut self,
555 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
556 ) -> Result<(), Self::Error> {
557 let mut ids = Vec::with_capacity(activity.len());
558 let mut last_activities = Vec::with_capacity(activity.len());
559 let mut ips = Vec::with_capacity(activity.len());
560
561 for (id, last_activity, ip) in activity {
562 ids.push(Uuid::from(id));
563 last_activities.push(last_activity);
564 ips.push(ip);
565 }
566
567 let res = sqlx::query!(
568 r#"
569 UPDATE compat_sessions
570 SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
571 , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
572 FROM (
573 SELECT *
574 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
575 AS t(compat_session_id, last_active_at, last_active_ip)
576 ) AS t
577 WHERE compat_sessions.compat_session_id = t.compat_session_id
578 "#,
579 &ids,
580 &last_activities,
581 &ips as &[Option<IpAddr>],
582 )
583 .traced()
584 .execute(&mut *self.conn)
585 .await?;
586
587 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
588
589 Ok(())
590 }
591
592 #[tracing::instrument(
593 name = "db.compat_session.record_user_agent",
594 skip_all,
595 fields(
596 db.query.text,
597 %compat_session.id,
598 ),
599 err,
600 )]
601 async fn record_user_agent(
602 &mut self,
603 mut compat_session: CompatSession,
604 user_agent: String,
605 ) -> Result<CompatSession, Self::Error> {
606 let res = sqlx::query!(
607 r#"
608 UPDATE compat_sessions
609 SET user_agent = $2
610 WHERE compat_session_id = $1
611 "#,
612 Uuid::from(compat_session.id),
613 &*user_agent,
614 )
615 .traced()
616 .execute(&mut *self.conn)
617 .await?;
618
619 compat_session.user_agent = Some(user_agent);
620
621 DatabaseError::ensure_affected_rows(&res, 1)?;
622
623 Ok(compat_session)
624 }
625
626 #[tracing::instrument(
627 name = "repository.compat_session.set_human_name",
628 skip(self),
629 fields(
630 compat_session.id = %compat_session.id,
631 compat_session.human_name = ?human_name,
632 ),
633 err,
634 )]
635 async fn set_human_name(
636 &mut self,
637 mut compat_session: CompatSession,
638 human_name: Option<String>,
639 ) -> Result<CompatSession, Self::Error> {
640 let res = sqlx::query!(
641 r#"
642 UPDATE compat_sessions
643 SET human_name = $2
644 WHERE compat_session_id = $1
645 "#,
646 Uuid::from(compat_session.id),
647 human_name.as_deref(),
648 )
649 .traced()
650 .execute(&mut *self.conn)
651 .await?;
652
653 compat_session.human_name = human_name;
654
655 DatabaseError::ensure_affected_rows(&res, 1)?;
656
657 Ok(compat_session)
658 }
659}