mas_handlers/graphql/mutations/oauth2_session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use anyhow::Context as _;
8use async_graphql::{Context, Description, Enum, ID, InputObject, Object};
9use chrono::Duration;
10use mas_data_model::{Device, TokenType};
11use mas_storage::{
12    RepositoryAccess,
13    oauth2::{
14        OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository,
15        OAuth2SessionRepository,
16    },
17    queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
18    user::UserRepository,
19};
20use oauth2_types::scope::Scope;
21
22use crate::graphql::{
23    model::{NodeType, OAuth2Session},
24    state::ContextExt,
25};
26
/// GraphQL mutations that create, end and rename OAuth 2.0 sessions.
///
/// The private unit field prevents other modules from building this type
/// with a struct literal; obtain an instance through [`Default`] instead.
#[derive(Default)]
pub struct OAuth2SessionMutations {
    // Zero-sized marker: carries no data, only restricts construction.
    _private: (),
}
31
// NOTE: async-graphql turns the `///` comments on this type and its fields
// into GraphQL schema descriptions, so editing them changes the published
// schema. Implementation notes are kept in `//` comments instead.
/// The input of the `createOauth2Session` mutation.
#[derive(InputObject)]
pub struct CreateOAuth2SessionInput {
    /// The scope of the session
    // Parsed into a `Scope` by `create_oauth2_session`; a value that fails to
    // parse is rejected with an "Invalid scope" error.
    scope: String,

    /// The ID of the user for which to create the session
    // GraphQL node ID of a `User`, decoded with `NodeType::User.extract_ulid`.
    user_id: ID,

    /// Whether the session should issue a never-expiring access token
    // Treated as `false` when omitted. When `true`, the access token gets no
    // TTL and no refresh token is issued.
    permanent: Option<bool>,
}
44
45/// The payload of the `createOauth2Session` mutation.
46#[derive(Description)]
47pub struct CreateOAuth2SessionPayload {
48    access_token: String,
49    refresh_token: Option<String>,
50    session: mas_data_model::Session,
51}
52
53#[Object(use_type_description)]
54impl CreateOAuth2SessionPayload {
55    /// Access token for this session
56    pub async fn access_token(&self) -> &str {
57        &self.access_token
58    }
59
60    /// Refresh token for this session, if it is not a permanent session
61    pub async fn refresh_token(&self) -> Option<&str> {
62        self.refresh_token.as_deref()
63    }
64
65    /// The OAuth 2.0 session which was just created
66    pub async fn oauth2_session(&self) -> OAuth2Session {
67        OAuth2Session(self.session.clone())
68    }
69}
70
// NOTE: the `///` comments here become GraphQL schema descriptions via
// async-graphql; implementation notes are kept in `//` comments.
/// The input of the `endOauth2Session` mutation.
#[derive(InputObject)]
pub struct EndOAuth2SessionInput {
    /// The ID of the session to end.
    // GraphQL node ID of an `OAuth2Session`, decoded with
    // `NodeType::OAuth2Session.extract_ulid`.
    oauth2_session_id: ID,
}
77
78/// The payload of the `endOauth2Session` mutation.
79pub enum EndOAuth2SessionPayload {
80    NotFound,
81    Ended(mas_data_model::Session),
82}
83
84/// The status of the `endOauth2Session` mutation.
85#[derive(Enum, Copy, Clone, PartialEq, Eq, Debug)]
86enum EndOAuth2SessionStatus {
87    /// The session was ended.
88    Ended,
89
90    /// The session was not found.
91    NotFound,
92}
93
94#[Object]
95impl EndOAuth2SessionPayload {
96    /// The status of the mutation.
97    async fn status(&self) -> EndOAuth2SessionStatus {
98        match self {
99            Self::Ended(_) => EndOAuth2SessionStatus::Ended,
100            Self::NotFound => EndOAuth2SessionStatus::NotFound,
101        }
102    }
103
104    /// Returns the ended session.
105    async fn oauth2_session(&self) -> Option<OAuth2Session> {
106        match self {
107            Self::Ended(session) => Some(OAuth2Session(session.clone())),
108            Self::NotFound => None,
109        }
110    }
111}
112
// NOTE: the `///` comments here become GraphQL schema descriptions via
// async-graphql; implementation notes are kept in `//` comments.
/// The input of the `setOauth2SessionName` mutation.
#[derive(InputObject)]
pub struct SetOAuth2SessionNameInput {
    /// The ID of the session to set the name of.
    // GraphQL node ID of an `OAuth2Session`, decoded with
    // `NodeType::OAuth2Session.extract_ulid`.
    oauth2_session_id: ID,

    /// The new name of the session.
    // Stored on the session and also pushed to the homeserver as the display
    // name of every device derived from the session's scope.
    human_name: String,
}
122
123/// The payload of the `setOauth2SessionName` mutation.
124pub enum SetOAuth2SessionNamePayload {
125    /// The session was not found.
126    NotFound,
127
128    /// The session was updated.
129    Updated(mas_data_model::Session),
130}
131
132/// The status of the `setOauth2SessionName` mutation.
133#[derive(Enum, Copy, Clone, PartialEq, Eq, Debug)]
134enum SetOAuth2SessionNameStatus {
135    /// The session was updated.
136    Updated,
137
138    /// The session was not found.
139    NotFound,
140}
141
142#[Object]
143impl SetOAuth2SessionNamePayload {
144    /// The status of the mutation.
145    async fn status(&self) -> SetOAuth2SessionNameStatus {
146        match self {
147            Self::Updated(_) => SetOAuth2SessionNameStatus::Updated,
148            Self::NotFound => SetOAuth2SessionNameStatus::NotFound,
149        }
150    }
151
152    /// The session that was updated.
153    async fn oauth2_session(&self) -> Option<OAuth2Session> {
154        match self {
155            Self::Updated(session) => Some(OAuth2Session(session.clone())),
156            Self::NotFound => None,
157        }
158    }
159}
160
161#[Object]
162impl OAuth2SessionMutations {
163    /// Create a new arbitrary OAuth 2.0 Session.
164    ///
165    /// Only available for administrators.
166    async fn create_oauth2_session(
167        &self,
168        ctx: &Context<'_>,
169        input: CreateOAuth2SessionInput,
170    ) -> Result<CreateOAuth2SessionPayload, async_graphql::Error> {
171        let state = ctx.state();
172        let homeserver = state.homeserver_connection();
173        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
174        let scope: Scope = input.scope.parse().context("Invalid scope")?;
175        let permanent = input.permanent.unwrap_or(false);
176        let requester = ctx.requester();
177
178        if !requester.is_admin() {
179            return Err(async_graphql::Error::new("Unauthorized"));
180        }
181
182        let session = requester
183            .oauth2_session()
184            .context("Requester should be a OAuth 2.0 client")?;
185
186        let mut repo = state.repository().await?;
187        let clock = state.clock();
188        let mut rng = state.rng();
189
190        let client = repo
191            .oauth2_client()
192            .lookup(session.client_id)
193            .await?
194            .context("Client not found")?;
195
196        let user = repo
197            .user()
198            .lookup(user_id)
199            .await?
200            .context("User not found")?;
201
202        // Generate a new access token
203        let access_token = TokenType::AccessToken.generate(&mut rng);
204
205        // Create the OAuth 2.0 Session
206        let session = repo
207            .oauth2_session()
208            .add(&mut rng, &clock, &client, Some(&user), None, scope)
209            .await?;
210
211        // Lock the user sync to make sure we don't get into a race condition
212        repo.user().acquire_lock_for_sync(&user).await?;
213
214        // Look for devices to provision
215        let mxid = homeserver.mxid(&user.username);
216        for scope in &*session.scope {
217            if let Some(device) = Device::from_scope_token(scope) {
218                homeserver
219                    .create_device(&mxid, device.as_str(), None)
220                    .await
221                    .context("Failed to provision device")?;
222            }
223        }
224
225        let ttl = if permanent {
226            None
227        } else {
228            Some(Duration::microseconds(5 * 60 * 1000 * 1000))
229        };
230        let access_token = repo
231            .oauth2_access_token()
232            .add(&mut rng, &clock, &session, access_token, ttl)
233            .await?;
234
235        let refresh_token = if permanent {
236            None
237        } else {
238            let refresh_token = TokenType::RefreshToken.generate(&mut rng);
239
240            let refresh_token = repo
241                .oauth2_refresh_token()
242                .add(&mut rng, &clock, &session, &access_token, refresh_token)
243                .await?;
244
245            Some(refresh_token)
246        };
247
248        repo.save().await?;
249
250        Ok(CreateOAuth2SessionPayload {
251            session,
252            access_token: access_token.access_token,
253            refresh_token: refresh_token.map(|t| t.refresh_token),
254        })
255    }
256
257    async fn end_oauth2_session(
258        &self,
259        ctx: &Context<'_>,
260        input: EndOAuth2SessionInput,
261    ) -> Result<EndOAuth2SessionPayload, async_graphql::Error> {
262        let state = ctx.state();
263        let oauth2_session_id = NodeType::OAuth2Session.extract_ulid(&input.oauth2_session_id)?;
264        let requester = ctx.requester();
265
266        let mut repo = state.repository().await?;
267        let clock = state.clock();
268        let mut rng = state.rng();
269
270        let session = repo.oauth2_session().lookup(oauth2_session_id).await?;
271        let Some(session) = session else {
272            return Ok(EndOAuth2SessionPayload::NotFound);
273        };
274
275        if !requester.is_owner_or_admin(&session) {
276            return Ok(EndOAuth2SessionPayload::NotFound);
277        }
278
279        if let Some(user_id) = session.user_id {
280            let user = repo
281                .user()
282                .lookup(user_id)
283                .await?
284                .context("Could not load user")?;
285
286            // Schedule a job to sync the devices of the user with the homeserver
287            repo.queue_job()
288                .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
289                .await?;
290        }
291
292        let session = repo.oauth2_session().finish(&clock, session).await?;
293
294        repo.save().await?;
295
296        Ok(EndOAuth2SessionPayload::Ended(session))
297    }
298
299    async fn set_oauth2_session_name(
300        &self,
301        ctx: &Context<'_>,
302        input: SetOAuth2SessionNameInput,
303    ) -> Result<SetOAuth2SessionNamePayload, async_graphql::Error> {
304        let state = ctx.state();
305        let oauth2_session_id = NodeType::OAuth2Session.extract_ulid(&input.oauth2_session_id)?;
306        let requester = ctx.requester();
307
308        let mut repo = state.repository().await?;
309        let homeserver = state.homeserver_connection();
310
311        let session = repo.oauth2_session().lookup(oauth2_session_id).await?;
312        let Some(session) = session else {
313            return Ok(SetOAuth2SessionNamePayload::NotFound);
314        };
315
316        if !requester.is_owner_or_admin(&session) {
317            return Ok(SetOAuth2SessionNamePayload::NotFound);
318        }
319
320        let user_id = session.user_id.context("Session has no user")?;
321
322        let user = repo
323            .user()
324            .lookup(user_id)
325            .await?
326            .context("User not found")?;
327
328        let session = repo
329            .oauth2_session()
330            .set_human_name(session, Some(input.human_name.clone()))
331            .await?;
332
333        // Update the device on the homeserver side
334        let mxid = homeserver.mxid(&user.username);
335        for scope in &*session.scope {
336            if let Some(device) = Device::from_scope_token(scope) {
337                homeserver
338                    .update_device_display_name(&mxid, device.as_str(), &input.human_name)
339                    .await
340                    .context("Failed to provision device")?;
341            }
342        }
343
344        repo.save().await?;
345
346        Ok(SetOAuth2SessionNamePayload::Updated(session))
347    }
348}