1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13 Clock, Page, Pagination,
14 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError, DatabaseInconsistencyError,
26 filter::{Filter, StatementExt},
27 iden::{OAuth2Clients, OAuth2Sessions},
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgOAuth2SessionRepository<'c> {
34 conn: &'c mut PgConnection,
35}
36
37impl<'c> PgOAuth2SessionRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
/// Raw row read from the `oauth2_sessions` table.
///
/// Field names must match the selected column names: `sqlx::FromRow` maps
/// columns to fields by name, and `#[enum_def]` derives the
/// `OAuthSessionLookupIden` enum that dynamically-built queries use to alias
/// their selected columns. Rows are turned into a [`Session`] through the
/// `TryFrom` implementation below.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    // Session ID, stored as a UUID and converted back to a ULID
    oauth2_session_id: Uuid,
    // Owning user, or NULL for sessions not tied to a user
    user_id: Option<Uuid>,
    // Browser session this session was created from, if any
    user_session_id: Option<Uuid>,
    // Client the session was issued to
    oauth2_client_id: Uuid,
    // Scope tokens as plain strings; each is parsed into a `ScopeToken` on conversion
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // NULL while the session is still valid
    finished_at: Option<DateTime<Utc>>,
    // Last user agent recorded for the session, if any
    user_agent: Option<String>,
    // Last recorded activity time and client IP, if any
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Optional human-readable label for the session
    human_name: Option<String>,
}
60
61impl TryFrom<OAuthSessionLookup> for Session {
62 type Error = DatabaseInconsistencyError;
63
64 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
65 let id = Ulid::from(value.oauth2_session_id);
66 let scope: Result<Scope, _> = value
67 .scope_list
68 .iter()
69 .map(|s| s.parse::<ScopeToken>())
70 .collect();
71 let scope = scope.map_err(|e| {
72 DatabaseInconsistencyError::on("oauth2_sessions")
73 .column("scope")
74 .row(id)
75 .source(e)
76 })?;
77
78 let state = match value.finished_at {
79 None => SessionState::Valid,
80 Some(finished_at) => SessionState::Finished { finished_at },
81 };
82
83 Ok(Session {
84 id,
85 state,
86 created_at: value.created_at,
87 client_id: value.oauth2_client_id.into(),
88 user_id: value.user_id.map(Ulid::from),
89 user_session_id: value.user_session_id.map(Ulid::from),
90 scope,
91 user_agent: value.user_agent,
92 last_active_at: value.last_active_at,
93 last_active_ip: value.last_active_ip,
94 human_name: value.human_name,
95 })
96 }
97}
98
99impl Filter for OAuth2SessionFilter<'_> {
100 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
101 sea_query::Condition::all()
102 .add_option(self.user().map(|user| {
103 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
104 }))
105 .add_option(self.client().map(|client| {
106 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
107 .eq(Uuid::from(client.id))
108 }))
109 .add_option(self.client_kind().map(|client_kind| {
110 let static_clients = Query::select()
114 .expr(Expr::col((
115 OAuth2Clients::Table,
116 OAuth2Clients::OAuth2ClientId,
117 )))
118 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
119 .from(OAuth2Clients::Table)
120 .take();
121 if client_kind.is_static() {
122 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
123 .eq(Expr::any(static_clients))
124 } else {
125 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
126 .ne(Expr::all(static_clients))
127 }
128 }))
129 .add_option(self.device().map(|device| {
130 if let Ok(scope_token) = device.to_scope_token() {
131 Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
132 OAuth2Sessions::Table,
133 OAuth2Sessions::ScopeList,
134 ))))
135 } else {
136 Expr::val(false).into()
138 }
139 }))
140 .add_option(self.browser_session().map(|browser_session| {
141 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
142 .eq(Uuid::from(browser_session.id))
143 }))
144 .add_option(self.state().map(|state| {
145 if state.is_active() {
146 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
147 } else {
148 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
149 }
150 }))
151 .add_option(self.scope().map(|scope| {
152 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
153 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
154 }))
155 .add_option(self.any_user().map(|any_user| {
156 if any_user {
157 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
158 } else {
159 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
160 }
161 }))
162 .add_option(self.last_active_after().map(|last_active_after| {
163 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
164 .gt(last_active_after)
165 }))
166 .add_option(self.last_active_before().map(|last_active_before| {
167 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
168 .lt(last_active_before)
169 }))
170 }
171}
172
173#[async_trait]
174impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
175 type Error = DatabaseError;
176
177 #[tracing::instrument(
178 name = "db.oauth2_session.lookup",
179 skip_all,
180 fields(
181 db.query.text,
182 session.id = %id,
183 ),
184 err,
185 )]
186 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
187 let res = sqlx::query_as!(
188 OAuthSessionLookup,
189 r#"
190 SELECT oauth2_session_id
191 , user_id
192 , user_session_id
193 , oauth2_client_id
194 , scope_list
195 , created_at
196 , finished_at
197 , user_agent
198 , last_active_at
199 , last_active_ip as "last_active_ip: IpAddr"
200 , human_name
201 FROM oauth2_sessions
202
203 WHERE oauth2_session_id = $1
204 "#,
205 Uuid::from(id),
206 )
207 .traced()
208 .fetch_optional(&mut *self.conn)
209 .await?;
210
211 let Some(session) = res else { return Ok(None) };
212
213 Ok(Some(session.try_into()?))
214 }
215
216 #[tracing::instrument(
217 name = "db.oauth2_session.add",
218 skip_all,
219 fields(
220 db.query.text,
221 %client.id,
222 session.id,
223 session.scope = %scope,
224 ),
225 err,
226 )]
227 async fn add(
228 &mut self,
229 rng: &mut (dyn RngCore + Send),
230 clock: &dyn Clock,
231 client: &Client,
232 user: Option<&User>,
233 user_session: Option<&BrowserSession>,
234 scope: Scope,
235 ) -> Result<Session, Self::Error> {
236 let created_at = clock.now();
237 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
238 tracing::Span::current().record("session.id", tracing::field::display(id));
239
240 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
241
242 sqlx::query!(
243 r#"
244 INSERT INTO oauth2_sessions
245 ( oauth2_session_id
246 , user_id
247 , user_session_id
248 , oauth2_client_id
249 , scope_list
250 , created_at
251 )
252 VALUES ($1, $2, $3, $4, $5, $6)
253 "#,
254 Uuid::from(id),
255 user.map(|u| Uuid::from(u.id)),
256 user_session.map(|s| Uuid::from(s.id)),
257 Uuid::from(client.id),
258 &scope_list,
259 created_at,
260 )
261 .traced()
262 .execute(&mut *self.conn)
263 .await?;
264
265 Ok(Session {
266 id,
267 state: SessionState::Valid,
268 created_at,
269 user_id: user.map(|u| u.id),
270 user_session_id: user_session.map(|s| s.id),
271 client_id: client.id,
272 scope,
273 user_agent: None,
274 last_active_at: None,
275 last_active_ip: None,
276 human_name: None,
277 })
278 }
279
280 #[tracing::instrument(
281 name = "db.oauth2_session.finish_bulk",
282 skip_all,
283 fields(
284 db.query.text,
285 ),
286 err,
287 )]
288 async fn finish_bulk(
289 &mut self,
290 clock: &dyn Clock,
291 filter: OAuth2SessionFilter<'_>,
292 ) -> Result<usize, Self::Error> {
293 let finished_at = clock.now();
294 let (sql, arguments) = Query::update()
295 .table(OAuth2Sessions::Table)
296 .value(OAuth2Sessions::FinishedAt, finished_at)
297 .apply_filter(filter)
298 .build_sqlx(PostgresQueryBuilder);
299
300 let res = sqlx::query_with(&sql, arguments)
301 .traced()
302 .execute(&mut *self.conn)
303 .await?;
304
305 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
306 }
307
308 #[tracing::instrument(
309 name = "db.oauth2_session.finish",
310 skip_all,
311 fields(
312 db.query.text,
313 %session.id,
314 %session.scope,
315 client.id = %session.client_id,
316 ),
317 err,
318 )]
319 async fn finish(
320 &mut self,
321 clock: &dyn Clock,
322 session: Session,
323 ) -> Result<Session, Self::Error> {
324 let finished_at = clock.now();
325 let res = sqlx::query!(
326 r#"
327 UPDATE oauth2_sessions
328 SET finished_at = $2
329 WHERE oauth2_session_id = $1
330 "#,
331 Uuid::from(session.id),
332 finished_at,
333 )
334 .traced()
335 .execute(&mut *self.conn)
336 .await?;
337
338 DatabaseError::ensure_affected_rows(&res, 1)?;
339
340 session
341 .finish(finished_at)
342 .map_err(DatabaseError::to_invalid_operation)
343 }
344
345 #[tracing::instrument(
346 name = "db.oauth2_session.list",
347 skip_all,
348 fields(
349 db.query.text,
350 ),
351 err,
352 )]
353 async fn list(
354 &mut self,
355 filter: OAuth2SessionFilter<'_>,
356 pagination: Pagination,
357 ) -> Result<Page<Session>, Self::Error> {
358 let (sql, arguments) = Query::select()
359 .expr_as(
360 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
361 OAuthSessionLookupIden::Oauth2SessionId,
362 )
363 .expr_as(
364 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
365 OAuthSessionLookupIden::UserId,
366 )
367 .expr_as(
368 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
369 OAuthSessionLookupIden::UserSessionId,
370 )
371 .expr_as(
372 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
373 OAuthSessionLookupIden::Oauth2ClientId,
374 )
375 .expr_as(
376 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
377 OAuthSessionLookupIden::ScopeList,
378 )
379 .expr_as(
380 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
381 OAuthSessionLookupIden::CreatedAt,
382 )
383 .expr_as(
384 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
385 OAuthSessionLookupIden::FinishedAt,
386 )
387 .expr_as(
388 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
389 OAuthSessionLookupIden::UserAgent,
390 )
391 .expr_as(
392 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
393 OAuthSessionLookupIden::LastActiveAt,
394 )
395 .expr_as(
396 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
397 OAuthSessionLookupIden::LastActiveIp,
398 )
399 .expr_as(
400 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
401 OAuthSessionLookupIden::HumanName,
402 )
403 .from(OAuth2Sessions::Table)
404 .apply_filter(filter)
405 .generate_pagination(
406 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
407 pagination,
408 )
409 .build_sqlx(PostgresQueryBuilder);
410
411 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
412 .traced()
413 .fetch_all(&mut *self.conn)
414 .await?;
415
416 let page = pagination.process(edges).try_map(Session::try_from)?;
417
418 Ok(page)
419 }
420
421 #[tracing::instrument(
422 name = "db.oauth2_session.count",
423 skip_all,
424 fields(
425 db.query.text,
426 ),
427 err,
428 )]
429 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
430 let (sql, arguments) = Query::select()
431 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
432 .from(OAuth2Sessions::Table)
433 .apply_filter(filter)
434 .build_sqlx(PostgresQueryBuilder);
435
436 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
437 .traced()
438 .fetch_one(&mut *self.conn)
439 .await?;
440
441 count
442 .try_into()
443 .map_err(DatabaseError::to_invalid_operation)
444 }
445
446 #[tracing::instrument(
447 name = "db.oauth2_session.record_batch_activity",
448 skip_all,
449 fields(
450 db.query.text,
451 ),
452 err,
453 )]
454 async fn record_batch_activity(
455 &mut self,
456 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
457 ) -> Result<(), Self::Error> {
458 let mut ids = Vec::with_capacity(activity.len());
459 let mut last_activities = Vec::with_capacity(activity.len());
460 let mut ips = Vec::with_capacity(activity.len());
461
462 for (id, last_activity, ip) in activity {
463 ids.push(Uuid::from(id));
464 last_activities.push(last_activity);
465 ips.push(ip);
466 }
467
468 let res = sqlx::query!(
469 r#"
470 UPDATE oauth2_sessions
471 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
472 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
473 FROM (
474 SELECT *
475 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
476 AS t(oauth2_session_id, last_active_at, last_active_ip)
477 ) AS t
478 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
479 "#,
480 &ids,
481 &last_activities,
482 &ips as &[Option<IpAddr>],
483 )
484 .traced()
485 .execute(&mut *self.conn)
486 .await?;
487
488 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
489
490 Ok(())
491 }
492
493 #[tracing::instrument(
494 name = "db.oauth2_session.record_user_agent",
495 skip_all,
496 fields(
497 db.query.text,
498 %session.id,
499 %session.scope,
500 client.id = %session.client_id,
501 session.user_agent = user_agent,
502 ),
503 err,
504 )]
505 async fn record_user_agent(
506 &mut self,
507 mut session: Session,
508 user_agent: String,
509 ) -> Result<Session, Self::Error> {
510 let res = sqlx::query!(
511 r#"
512 UPDATE oauth2_sessions
513 SET user_agent = $2
514 WHERE oauth2_session_id = $1
515 "#,
516 Uuid::from(session.id),
517 &*user_agent,
518 )
519 .traced()
520 .execute(&mut *self.conn)
521 .await?;
522
523 session.user_agent = Some(user_agent);
524
525 DatabaseError::ensure_affected_rows(&res, 1)?;
526
527 Ok(session)
528 }
529
530 #[tracing::instrument(
531 name = "repository.oauth2_session.set_human_name",
532 skip(self),
533 fields(
534 client.id = %session.client_id,
535 session.human_name = ?human_name,
536 ),
537 err,
538 )]
539 async fn set_human_name(
540 &mut self,
541 mut session: Session,
542 human_name: Option<String>,
543 ) -> Result<Session, Self::Error> {
544 let res = sqlx::query!(
545 r#"
546 UPDATE oauth2_sessions
547 SET human_name = $2
548 WHERE oauth2_session_id = $1
549 "#,
550 Uuid::from(session.id),
551 human_name.as_deref(),
552 )
553 .traced()
554 .execute(&mut *self.conn)
555 .await?;
556
557 session.human_name = human_name;
558
559 DatabaseError::ensure_affected_rows(&res, 1)?;
560
561 Ok(session)
562 }
563}