1use std::collections::HashSet;
8
9use anyhow::Context;
10use async_trait::async_trait;
11use mas_data_model::Device;
12use mas_matrix::ProvisionRequest;
13use mas_storage::{
14 Pagination, RepositoryAccess,
15 compat::CompatSessionFilter,
16 oauth2::OAuth2SessionFilter,
17 queue::{
18 DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
19 SyncDevicesJob,
20 },
21 user::{UserEmailRepository, UserRepository},
22};
23use tracing::info;
24
25use crate::{
26 State,
27 new_queue::{JobContext, JobError, RunnableJob},
28};
29
30#[async_trait]
34impl RunnableJob for ProvisionUserJob {
35 #[tracing::instrument(
36 name = "job.provision_user"
37 fields(user.id = %self.user_id()),
38 skip_all,
39 )]
40 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
41 let matrix = state.matrix_connection();
42 let mut repo = state.repository().await.map_err(JobError::retry)?;
43 let mut rng = state.rng();
44 let clock = state.clock();
45
46 let user = repo
47 .user()
48 .lookup(self.user_id())
49 .await
50 .map_err(JobError::retry)?
51 .context("User not found")
52 .map_err(JobError::fail)?;
53
54 let mxid = matrix.mxid(&user.username);
55 let emails = repo
56 .user_email()
57 .all(&user)
58 .await
59 .map_err(JobError::retry)?
60 .into_iter()
61 .map(|email| email.email)
62 .collect();
63 let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
64
65 if let Some(display_name) = self.display_name_to_set() {
66 request = request.set_displayname(display_name.to_owned());
67 }
68
69 let created = matrix
70 .provision_user(&request)
71 .await
72 .map_err(JobError::retry)?;
73
74 if created {
75 info!(%user.id, %mxid, "User created");
76 } else {
77 info!(%user.id, %mxid, "User updated");
78 }
79
80 let sync_device_job = SyncDevicesJob::new(&user);
82 repo.queue_job()
83 .schedule_job(&mut rng, &clock, sync_device_job)
84 .await
85 .map_err(JobError::retry)?;
86
87 repo.save().await.map_err(JobError::retry)?;
88
89 Ok(())
90 }
91}
92
93#[async_trait]
97impl RunnableJob for ProvisionDeviceJob {
98 #[tracing::instrument(
99 name = "job.provision_device"
100 fields(
101 user.id = %self.user_id(),
102 device.id = %self.device_id(),
103 ),
104 skip_all,
105 )]
106 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
107 let mut repo = state.repository().await.map_err(JobError::retry)?;
108 let mut rng = state.rng();
109 let clock = state.clock();
110
111 let user = repo
112 .user()
113 .lookup(self.user_id())
114 .await
115 .map_err(JobError::retry)?
116 .context("User not found")
117 .map_err(JobError::fail)?;
118
119 repo.queue_job()
121 .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
122 .await
123 .map_err(JobError::retry)?;
124
125 Ok(())
126 }
127}
128
129#[async_trait]
133impl RunnableJob for DeleteDeviceJob {
134 #[tracing::instrument(
135 name = "job.delete_device"
136 fields(
137 user.id = %self.user_id(),
138 device.id = %self.device_id(),
139 ),
140 skip_all,
141 )]
142 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
143 let mut rng = state.rng();
144 let clock = state.clock();
145 let mut repo = state.repository().await.map_err(JobError::retry)?;
146
147 let user = repo
148 .user()
149 .lookup(self.user_id())
150 .await
151 .map_err(JobError::retry)?
152 .context("User not found")
153 .map_err(JobError::fail)?;
154
155 repo.queue_job()
157 .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
158 .await
159 .map_err(JobError::retry)?;
160
161 Ok(())
162 }
163}
164
165#[async_trait]
167impl RunnableJob for SyncDevicesJob {
168 #[tracing::instrument(
169 name = "job.sync_devices",
170 fields(user.id = %self.user_id()),
171 skip_all,
172 )]
173 async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
174 let matrix = state.matrix_connection();
175 let mut repo = state.repository().await.map_err(JobError::retry)?;
176
177 let user = repo
178 .user()
179 .lookup(self.user_id())
180 .await
181 .map_err(JobError::retry)?
182 .context("User not found")
183 .map_err(JobError::fail)?;
184
185 repo.user()
187 .acquire_lock_for_sync(&user)
188 .await
189 .map_err(JobError::retry)?;
190
191 let mut devices = HashSet::new();
192
193 let mut cursor = Pagination::first(100);
195 loop {
196 let page = repo
197 .compat_session()
198 .list(
199 CompatSessionFilter::new().for_user(&user).active_only(),
200 cursor,
201 )
202 .await
203 .map_err(JobError::retry)?;
204
205 for (compat_session, _) in page.edges {
206 if let Some(ref device) = compat_session.device {
207 devices.insert(device.as_str().to_owned());
208 }
209 cursor = cursor.after(compat_session.id);
210 }
211
212 if !page.has_next_page {
213 break;
214 }
215 }
216
217 let mut cursor = Pagination::first(100);
219 loop {
220 let page = repo
221 .oauth2_session()
222 .list(
223 OAuth2SessionFilter::new().for_user(&user).active_only(),
224 cursor,
225 )
226 .await
227 .map_err(JobError::retry)?;
228
229 for oauth2_session in page.edges {
230 for scope in &*oauth2_session.scope {
231 if let Some(device) = Device::from_scope_token(scope) {
232 devices.insert(device.as_str().to_owned());
233 }
234 }
235
236 cursor = cursor.after(oauth2_session.id);
237 }
238
239 if !page.has_next_page {
240 break;
241 }
242 }
243
244 let mxid = matrix.mxid(&user.username);
245 matrix
246 .sync_devices(&mxid, devices)
247 .await
248 .map_err(JobError::retry)?;
249
250 repo.save().await.map_err(JobError::retry)?;
253
254 Ok(())
255 }
256}