syn2mas/
migration.rs

1// Copyright 2024 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6//! # Migration
7//!
8//! This module provides the high-level logic for performing the Synapse-to-MAS
9//! database migration.
10//!
11//! This module does not implement any of the safety checks that should be run
12//! *before* the migration.
13
14use std::time::Instant;
15
16use chrono::{DateTime, Utc};
17use compact_str::CompactString;
18use futures_util::{SinkExt, StreamExt as _, TryFutureExt, TryStreamExt as _};
19use mas_storage::Clock;
20use rand::{RngCore, SeedableRng};
21use thiserror::Error;
22use thiserror_ext::ContextInto;
23use tokio_util::sync::PollSender;
24use tracing::{Instrument as _, Level, info};
25use ulid::Ulid;
26use uuid::{NonNilUuid, Uuid};
27
28use crate::{
29    HashMap, ProgressCounter, RandomState, SynapseReader,
30    mas_writer::{
31        self, MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
32        MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
33        MasNewUserPassword, MasWriteBuffer, MasWriter,
34    },
35    progress::{EntityType, Progress},
36    synapse_reader::{
37        self, ExtractLocalpartError, FullUserId, SynapseAccessToken, SynapseDevice,
38        SynapseExternalId, SynapseRefreshableTokenPair, SynapseThreepid, SynapseUser,
39    },
40};
41
/// Errors that can abort the Synapse-to-MAS migration.
///
/// The `into_synapse`, `into_mas`, `into_extract_localpart` and `into_join`
/// helpers used throughout this module are presumably generated by the
/// `ContextInto` derive from the variants below — confirm against
/// `thiserror_ext` before renaming a variant.
#[derive(Debug, Error, ContextInto)]
pub enum Error {
    /// Reading from the Synapse database failed; `context` says what was being
    /// read.
    #[error("error when reading synapse DB ({context}): {source}")]
    Synapse {
        source: synapse_reader::Error,
        context: String,
    },
    /// Writing to the MAS database failed; `context` says what was being
    /// written.
    #[error("error when writing to MAS DB ({context}): {source}")]
    Mas {
        source: mas_writer::Error,
        context: String,
    },
    /// The localpart could not be extracted from a Synapse user ID.
    #[error("failed to extract localpart of {user:?}: {source}")]
    ExtractLocalpart {
        source: ExtractLocalpartError,
        user: FullUserId,
    },
    /// The channel feeding a write task was closed before the reader finished
    /// sending rows.
    #[error("channel closed")]
    ChannelClosed,

    /// A spawned write task could not be joined.
    #[error("task failed ({context}): {source}")]
    Join {
        source: tokio::task::JoinError,
        context: String,
    },

    /// A row in a dependent table refers to a user that was not seen while
    /// migrating users.
    #[error("user {user} was not found for migration but a row in {table} was found for them")]
    MissingUserFromDependentTable { table: String, user: FullUserId },
    /// An external ID uses a Synapse auth provider that is absent from the
    /// provider ID mapping given to the migration.
    #[error(
        "missing a mapping for the auth provider with ID {synapse_id:?} (used by {user} and maybe other users)"
    )]
    MissingAuthProviderMapping {
        /// `auth_provider` ID of the provider in Synapse, for which we have no
        /// mapping
        synapse_id: String,
        /// a user that is using this auth provider
        user: FullUserId,
    },
}
81
bitflags::bitflags! {
    /// Per-user attributes recorded while migrating users and consulted when
    /// migrating the rows that depend on them.
    #[derive(Debug, Clone, Copy)]
    struct UserFlags: u8 {
        /// Set when the Synapse user's `admin` field is true.
        const IS_SYNAPSE_ADMIN = 0b0000_0001;
        /// Set when the Synapse user's `deactivated` field is true.
        const IS_DEACTIVATED = 0b0000_0010;
        /// Set when the Synapse user's `is_guest` field is true.
        const IS_GUEST = 0b0000_0100;
        /// Set when the Synapse user has an `appservice_id`.
        const IS_APPSERVICE = 0b0000_1000;
    }
}
91
92impl UserFlags {
93    const fn is_deactivated(self) -> bool {
94        self.contains(UserFlags::IS_DEACTIVATED)
95    }
96
97    const fn is_guest(self) -> bool {
98        self.contains(UserFlags::IS_GUEST)
99    }
100
101    const fn is_synapse_admin(self) -> bool {
102        self.contains(UserFlags::IS_SYNAPSE_ADMIN)
103    }
104
105    const fn is_appservice(self) -> bool {
106        self.contains(UserFlags::IS_APPSERVICE)
107    }
108}
109
/// What the migration remembers about a Synapse user once the user stage has
/// run.
#[derive(Debug, Clone, Copy)]
struct UserInfo {
    /// ID of the user written to MAS. `None` for appservice users, which are
    /// recorded in the state but not written to MAS.
    mas_user_id: Option<NonNilUuid>,
    /// Attributes of the Synapse user (admin, deactivated, guest, appservice).
    flags: UserFlags,
}
115
/// State carried from one migration stage to the next.
///
/// Each stage takes the state by value, moves it into its write task and
/// hands it back when the task completes.
struct MigrationState {
    /// The server name we're migrating from
    server_name: String,

    /// Lookup table from user localpart to that user's infos
    users: HashMap<CompactString, UserInfo>,

    /// Mapping of MAS user ID + device ID to a MAS compat session ID.
    ///
    /// Populated by the access token stage(s) and by the device stage for
    /// devices that were not seen before.
    devices_to_compat_sessions: HashMap<(NonNilUuid, CompactString), Uuid>,

    /// A mapping of Synapse external ID providers to MAS upstream OAuth 2.0
    /// provider ID
    provider_id_mapping: std::collections::HashMap<String, Uuid>,
}
130
131/// Performs a migration from Synapse's database to MAS' database.
132///
133/// # Panics
134///
135/// - If there are more than `usize::MAX` users
136///
137/// # Errors
138///
139/// Errors are returned under the following circumstances:
140///
141/// - An underlying database access error, either to MAS or to Synapse.
142/// - Invalid data in the Synapse database.
143#[expect(clippy::implicit_hasher)]
144pub async fn migrate(
145    mut synapse: SynapseReader<'_>,
146    mas: MasWriter,
147    server_name: String,
148    clock: &dyn Clock,
149    rng: &mut impl RngCore,
150    provider_id_mapping: std::collections::HashMap<String, Uuid>,
151    progress: &Progress,
152) -> Result<(), Error> {
153    let counts = synapse.count_rows().await.into_synapse("counting users")?;
154
155    let state = MigrationState {
156        server_name,
157        // We oversize the hashmaps, as the estimates are innaccurate, and we would like to avoid
158        // reallocations.
159        users: HashMap::with_capacity_and_hasher(counts.users * 9 / 8, RandomState::default()),
160        devices_to_compat_sessions: HashMap::with_capacity_and_hasher(
161            counts.devices * 9 / 8,
162            RandomState::default(),
163        ),
164        provider_id_mapping,
165    };
166
167    let progress_counter = progress.migrating_data(EntityType::Users, counts.users);
168    let (mas, state) = migrate_users(&mut synapse, mas, state, rng, progress_counter).await?;
169
170    let progress_counter = progress.migrating_data(EntityType::ThreePids, counts.threepids);
171    let (mas, state) = migrate_threepids(&mut synapse, mas, rng, state, progress_counter).await?;
172
173    let progress_counter = progress.migrating_data(EntityType::ExternalIds, counts.external_ids);
174    let (mas, state) =
175        migrate_external_ids(&mut synapse, mas, rng, state, progress_counter).await?;
176
177    let progress_counter = progress.migrating_data(
178        EntityType::NonRefreshableAccessTokens,
179        counts.access_tokens - counts.refresh_tokens,
180    );
181    let (mas, state) =
182        migrate_unrefreshable_access_tokens(&mut synapse, mas, clock, rng, state, progress_counter)
183            .await?;
184
185    let progress_counter =
186        progress.migrating_data(EntityType::RefreshableTokens, counts.refresh_tokens);
187    let (mas, state) =
188        migrate_refreshable_token_pairs(&mut synapse, mas, clock, rng, state, progress_counter)
189            .await?;
190
191    let progress_counter = progress.migrating_data(EntityType::Devices, counts.devices);
192    let (mas, _state) = migrate_devices(&mut synapse, mas, rng, state, progress_counter).await?;
193
194    synapse
195        .finish()
196        .await
197        .into_synapse("failed to close Synapse reader")?;
198
199    mas.finish(progress)
200        .await
201        .into_mas("failed to finalise MAS database")?;
202
203    Ok(())
204}
205
206#[tracing::instrument(skip_all, level = Level::INFO)]
207async fn migrate_users(
208    synapse: &mut SynapseReader<'_>,
209    mut mas: MasWriter,
210    mut state: MigrationState,
211    rng: &mut impl RngCore,
212    progress_counter: ProgressCounter,
213) -> Result<(MasWriter, MigrationState), Error> {
214    let start = Instant::now();
215    let progress_counter_ = progress_counter.clone();
216
217    let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseUser>(100 * 1024);
218
219    // create a new RNG seeded from the passed RNG so that we can move it into the
220    // spawned task
221    let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
222    let task = tokio::spawn(
223        async move {
224            let mut user_buffer = MasWriteBuffer::new(&mas);
225            let mut password_buffer = MasWriteBuffer::new(&mas);
226
227            while let Some(user) = rx.recv().await {
228                // Handling an edge case: some AS users may have invalid localparts containing
229                // extra `:` characters. These users are ignored and a warning is logged.
230                if user.appservice_id.is_some()
231                    && user
232                        .name
233                        .0
234                        .strip_suffix(&format!(":{}", state.server_name))
235                        .is_some_and(|localpart| localpart.contains(':'))
236                {
237                    tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0);
238                    continue;
239                }
240
241                let (mas_user, mas_password_opt) =
242                    transform_user(&user, &state.server_name, &mut rng)?;
243
244                let mut flags = UserFlags::empty();
245                if bool::from(user.admin) {
246                    flags |= UserFlags::IS_SYNAPSE_ADMIN;
247                }
248                if bool::from(user.deactivated) {
249                    flags |= UserFlags::IS_DEACTIVATED;
250                }
251                if bool::from(user.is_guest) {
252                    flags |= UserFlags::IS_GUEST;
253                }
254                if user.appservice_id.is_some() {
255                    flags |= UserFlags::IS_APPSERVICE;
256
257                    progress_counter.increment_skipped();
258
259                    // Special case for appservice users: we don't insert them into the database
260                    // We just record the user's information in the state and continue
261                    state.users.insert(
262                        CompactString::new(&mas_user.username),
263                        UserInfo {
264                            mas_user_id: None,
265                            flags,
266                        },
267                    );
268                    continue;
269                }
270
271                state.users.insert(
272                    CompactString::new(&mas_user.username),
273                    UserInfo {
274                        mas_user_id: Some(mas_user.user_id),
275                        flags,
276                    },
277                );
278
279                user_buffer
280                    .write(&mut mas, mas_user)
281                    .await
282                    .into_mas("writing user")?;
283
284                if let Some(mas_password) = mas_password_opt {
285                    password_buffer
286                        .write(&mut mas, mas_password)
287                        .await
288                        .into_mas("writing password")?;
289                }
290
291                progress_counter.increment_migrated();
292            }
293
294            user_buffer
295                .finish(&mut mas)
296                .await
297                .into_mas("writing users")?;
298            password_buffer
299                .finish(&mut mas)
300                .await
301                .into_mas("writing passwords")?;
302
303            Ok((mas, state))
304        }
305        .instrument(tracing::info_span!("ingest_task")),
306    );
307
308    // In case this has an error, we still want to join the task, so we look at the
309    // error later
310    let res = synapse
311        .read_users()
312        .map_err(|e| e.into_synapse("reading users"))
313        .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
314        .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
315        .await;
316
317    let (mas, state) = task.await.into_join("user write task")??;
318
319    res?;
320
321    info!(
322        "{} users migrated ({} skipped) in {:.1}s",
323        progress_counter_.migrated(),
324        progress_counter_.skipped(),
325        Instant::now().duration_since(start).as_secs_f64()
326    );
327
328    Ok((mas, state))
329}
330
331#[tracing::instrument(skip_all, level = Level::INFO)]
332async fn migrate_threepids(
333    synapse: &mut SynapseReader<'_>,
334    mut mas: MasWriter,
335    rng: &mut impl RngCore,
336    state: MigrationState,
337    progress_counter: ProgressCounter,
338) -> Result<(MasWriter, MigrationState), Error> {
339    let start = Instant::now();
340    let progress_counter_ = progress_counter.clone();
341
342    let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseThreepid>(100 * 1024);
343
344    // create a new RNG seeded from the passed RNG so that we can move it into the
345    // spawned task
346    let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
347    let task = tokio::spawn(
348        async move {
349            let mut email_buffer = MasWriteBuffer::new(&mas);
350            let mut unsupported_buffer = MasWriteBuffer::new(&mas);
351
352            while let Some(threepid) = rx.recv().await {
353                let SynapseThreepid {
354                    user_id: synapse_user_id,
355                    medium,
356                    address,
357                    added_at,
358                } = threepid;
359                let created_at: DateTime<Utc> = added_at.into();
360
361                let username = synapse_user_id
362                    .extract_localpart(&state.server_name)
363                    .into_extract_localpart(synapse_user_id.clone())?
364                    .to_owned();
365                let Some(user_infos) = state.users.get(username.as_str()).copied() else {
366                    return Err(Error::MissingUserFromDependentTable {
367                        table: "user_threepids".to_owned(),
368                        user: synapse_user_id,
369                    });
370                };
371
372                let Some(mas_user_id) = user_infos.mas_user_id else {
373                    progress_counter.increment_skipped();
374                    continue;
375                };
376
377                if medium == "email" {
378                    email_buffer
379                        .write(
380                            &mut mas,
381                            MasNewEmailThreepid {
382                                user_id: mas_user_id,
383                                user_email_id: Uuid::from(Ulid::from_datetime_with_source(
384                                    created_at.into(),
385                                    &mut rng,
386                                )),
387                                email: address,
388                                created_at,
389                            },
390                        )
391                        .await
392                        .into_mas("writing email")?;
393                } else {
394                    unsupported_buffer
395                        .write(
396                            &mut mas,
397                            MasNewUnsupportedThreepid {
398                                user_id: mas_user_id,
399                                medium,
400                                address,
401                                created_at,
402                            },
403                        )
404                        .await
405                        .into_mas("writing unsupported threepid")?;
406                }
407
408                progress_counter.increment_migrated();
409            }
410
411            email_buffer
412                .finish(&mut mas)
413                .await
414                .into_mas("writing email threepids")?;
415            unsupported_buffer
416                .finish(&mut mas)
417                .await
418                .into_mas("writing unsupported threepids")?;
419
420            Ok((mas, state))
421        }
422        .instrument(tracing::info_span!("ingest_task")),
423    );
424
425    // In case this has an error, we still want to join the task, so we look at the
426    // error later
427    let res = synapse
428        .read_threepids()
429        .map_err(|e| e.into_synapse("reading threepids"))
430        .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
431        .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
432        .await;
433
434    let (mas, state) = task.await.into_join("threepid write task")??;
435
436    res?;
437
438    info!(
439        "{} third-party IDs migrated ({} skipped) in {:.1}s",
440        progress_counter_.migrated(),
441        progress_counter_.skipped(),
442        Instant::now().duration_since(start).as_secs_f64()
443    );
444
445    Ok((mas, state))
446}
447
448#[tracing::instrument(skip_all, level = Level::INFO)]
449async fn migrate_external_ids(
450    synapse: &mut SynapseReader<'_>,
451    mut mas: MasWriter,
452    rng: &mut impl RngCore,
453    state: MigrationState,
454    progress_counter: ProgressCounter,
455) -> Result<(MasWriter, MigrationState), Error> {
456    let start = Instant::now();
457    let progress_counter_ = progress_counter.clone();
458
459    let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseExternalId>(100 * 1024);
460
461    // create a new RNG seeded from the passed RNG so that we can move it into the
462    // spawned task
463    let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
464    let task = tokio::spawn(
465        async move {
466            let mut write_buffer = MasWriteBuffer::new(&mas);
467
468            while let Some(extid) = rx.recv().await {
469                let SynapseExternalId {
470                    user_id: synapse_user_id,
471                    auth_provider,
472                    external_id: subject,
473                } = extid;
474                let username = synapse_user_id
475                    .extract_localpart(&state.server_name)
476                    .into_extract_localpart(synapse_user_id.clone())?
477                    .to_owned();
478                let Some(user_infos) = state.users.get(username.as_str()).copied() else {
479                    return Err(Error::MissingUserFromDependentTable {
480                        table: "user_external_ids".to_owned(),
481                        user: synapse_user_id,
482                    });
483                };
484
485                let Some(mas_user_id) = user_infos.mas_user_id else {
486                    progress_counter.increment_skipped();
487                    continue;
488                };
489
490                let Some(&upstream_provider_id) = state.provider_id_mapping.get(&auth_provider)
491                else {
492                    return Err(Error::MissingAuthProviderMapping {
493                        synapse_id: auth_provider,
494                        user: synapse_user_id,
495                    });
496                };
497
498                // To save having to store user creation times, extract it from the ULID
499                // This gives millisecond precision — good enough.
500                let user_created_ts = Ulid::from(mas_user_id.get()).datetime();
501
502                let link_id: Uuid =
503                    Ulid::from_datetime_with_source(user_created_ts, &mut rng).into();
504
505                write_buffer
506                    .write(
507                        &mut mas,
508                        MasNewUpstreamOauthLink {
509                            link_id,
510                            user_id: mas_user_id,
511                            upstream_provider_id,
512                            subject,
513                            created_at: user_created_ts.into(),
514                        },
515                    )
516                    .await
517                    .into_mas("failed to write upstream link")?;
518
519                progress_counter.increment_migrated();
520            }
521
522            write_buffer
523                .finish(&mut mas)
524                .await
525                .into_mas("writing upstream links")?;
526
527            Ok((mas, state))
528        }
529        .instrument(tracing::info_span!("ingest_task")),
530    );
531
532    // In case this has an error, we still want to join the task, so we look at the
533    // error later
534    let res = synapse
535        .read_user_external_ids()
536        .map_err(|e| e.into_synapse("reading external ID"))
537        .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
538        .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
539        .await;
540
541    let (mas, state) = task.await.into_join("external IDs write task")??;
542
543    res?;
544
545    info!(
546        "{} upstream links (external IDs) migrated ({} skipped) in {:.1}s",
547        progress_counter_.migrated(),
548        progress_counter_.skipped(),
549        Instant::now().duration_since(start).as_secs_f64()
550    );
551
552    Ok((mas, state))
553}
554
555/// Migrate devices from Synapse to MAS (as compat sessions).
556///
557/// In order to get the right session creation timestamps, the access tokens
558/// must counterintuitively be migrated first, with the ULIDs passed in as
559/// `devices`.
560///
561/// This is because only access tokens store a timestamp that in any way
562/// resembles a creation timestamp.
563#[tracing::instrument(skip_all, level = Level::INFO)]
564async fn migrate_devices(
565    synapse: &mut SynapseReader<'_>,
566    mut mas: MasWriter,
567    rng: &mut impl RngCore,
568    mut state: MigrationState,
569    progress_counter: ProgressCounter,
570) -> Result<(MasWriter, MigrationState), Error> {
571    let start = Instant::now();
572    let progress_counter_ = progress_counter.clone();
573
574    let (tx, mut rx) = tokio::sync::mpsc::channel(100 * 1024);
575
576    // create a new RNG seeded from the passed RNG so that we can move it into the
577    // spawned task
578    let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
579    let task = tokio::spawn(
580        async move {
581            let mut write_buffer = MasWriteBuffer::new(&mas);
582
583            while let Some(device) = rx.recv().await {
584                let SynapseDevice {
585                    user_id: synapse_user_id,
586                    device_id,
587                    display_name,
588                    last_seen,
589                    ip,
590                    user_agent,
591                } = device;
592                let username = synapse_user_id
593                    .extract_localpart(&state.server_name)
594                    .into_extract_localpart(synapse_user_id.clone())?
595                    .to_owned();
596                let Some(user_infos) = state.users.get(username.as_str()).copied() else {
597                    return Err(Error::MissingUserFromDependentTable {
598                        table: "devices".to_owned(),
599                        user: synapse_user_id,
600                    });
601                };
602
603                let Some(mas_user_id) = user_infos.mas_user_id else {
604                    progress_counter.increment_skipped();
605                    continue;
606                };
607
608                if user_infos.flags.is_deactivated()
609                    || user_infos.flags.is_guest()
610                    || user_infos.flags.is_appservice()
611                {
612                    continue;
613                }
614
615                let session_id = *state
616                    .devices_to_compat_sessions
617                    .entry((mas_user_id, CompactString::new(&device_id)))
618                    .or_insert_with(||
619                // We don't have a creation time for this device (as it has no access token),
620                // so use now as a least-evil fallback.
621                Ulid::with_source(&mut rng).into());
622                let created_at = Ulid::from(session_id).datetime().into();
623
624                // As we're using a real IP type in the MAS database, it is possible
625                // that we encounter invalid IP addresses in the Synapse database.
626                // In that case, we should ignore them, but still log a warning.
627                // One special case: Synapse will record '-' as IP in some cases, we don't want
628                // to log about those
629                let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| {
630                    ip.parse()
631                        .map_err(|e| {
632                            tracing::warn!(
633                                error = &e as &dyn std::error::Error,
634                                mxid = %synapse_user_id,
635                                %device_id,
636                                %ip,
637                                "Failed to parse device IP, ignoring"
638                            );
639                        })
640                        .ok()
641                });
642
643                write_buffer
644                    .write(
645                        &mut mas,
646                        MasNewCompatSession {
647                            session_id,
648                            user_id: mas_user_id,
649                            device_id: Some(device_id),
650                            human_name: display_name,
651                            created_at,
652                            is_synapse_admin: user_infos.flags.is_synapse_admin(),
653                            last_active_at: last_seen.map(DateTime::from),
654                            last_active_ip,
655                            user_agent,
656                        },
657                    )
658                    .await
659                    .into_mas("writing compat sessions")?;
660
661                progress_counter.increment_migrated();
662            }
663
664            write_buffer
665                .finish(&mut mas)
666                .await
667                .into_mas("writing compat sessions")?;
668
669            Ok((mas, state))
670        }
671        .instrument(tracing::info_span!("ingest_task")),
672    );
673
674    // In case this has an error, we still want to join the task, so we look at the
675    // error later
676    let res = synapse
677        .read_devices()
678        .map_err(|e| e.into_synapse("reading devices"))
679        .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
680        .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
681        .await;
682
683    let (mas, state) = task.await.into_join("device write task")??;
684
685    res?;
686
687    info!(
688        "{} devices migrated ({} skipped) in {:.1}s",
689        progress_counter_.migrated(),
690        progress_counter_.skipped(),
691        Instant::now().duration_since(start).as_secs_f64()
692    );
693
694    Ok((mas, state))
695}
696
697/// Migrates unrefreshable access tokens (those without an associated refresh
698/// token). Some of these may be deviceless.
699#[tracing::instrument(skip_all, level = Level::INFO)]
700async fn migrate_unrefreshable_access_tokens(
701    synapse: &mut SynapseReader<'_>,
702    mut mas: MasWriter,
703    clock: &dyn Clock,
704    rng: &mut impl RngCore,
705    mut state: MigrationState,
706    progress_counter: ProgressCounter,
707) -> Result<(MasWriter, MigrationState), Error> {
708    let start = Instant::now();
709    let progress_counter_ = progress_counter.clone();
710
711    let (tx, mut rx) = tokio::sync::mpsc::channel(100 * 1024);
712
713    let now = clock.now();
714    // create a new RNG seeded from the passed RNG so that we can move it into the
715    // spawned task
716    let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
717    let task = tokio::spawn(
718        async move {
719            let mut write_buffer = MasWriteBuffer::new(&mas);
720            let mut deviceless_session_write_buffer = MasWriteBuffer::new(&mas);
721
722            while let Some(token) = rx.recv().await {
723                let SynapseAccessToken {
724                    user_id: synapse_user_id,
725                    device_id,
726                    token,
727                    valid_until_ms,
728                    last_validated,
729                } = token;
730                let username = synapse_user_id
731                    .extract_localpart(&state.server_name)
732                    .into_extract_localpart(synapse_user_id.clone())?
733                    .to_owned();
734                let Some(user_infos) = state.users.get(username.as_str()).copied() else {
735                    return Err(Error::MissingUserFromDependentTable {
736                        table: "access_tokens".to_owned(),
737                        user: synapse_user_id,
738                    });
739                };
740
741                let Some(mas_user_id) = user_infos.mas_user_id else {
742                    progress_counter.increment_skipped();
743                    continue;
744                };
745
746                if user_infos.flags.is_deactivated()
747                    || user_infos.flags.is_guest()
748                    || user_infos.flags.is_appservice()
749                {
750                    progress_counter.increment_skipped();
751                    continue;
752                }
753
754                // It's not always accurate, but last_validated is *often* the creation time of
755                // the device If we don't have one, then use the current time as a
756                // fallback.
757                let created_at = last_validated.map_or_else(|| now, DateTime::from);
758
759                let session_id = if let Some(device_id) = device_id {
760                    // Use the existing device_id if this is the second token for a device
761                    *state
762                        .devices_to_compat_sessions
763                        .entry((mas_user_id, CompactString::new(&device_id)))
764                        .or_insert_with(|| {
765                            Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
766                        })
767                } else {
768                    // If this is a deviceless access token, create a deviceless compat session
769                    // for it (since otherwise we won't create one whilst migrating devices)
770                    let deviceless_session_id =
771                        Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
772
773                    deviceless_session_write_buffer
774                        .write(
775                            &mut mas,
776                            MasNewCompatSession {
777                                session_id: deviceless_session_id,
778                                user_id: mas_user_id,
779                                device_id: None,
780                                human_name: None,
781                                created_at,
782                                is_synapse_admin: false,
783                                last_active_at: None,
784                                last_active_ip: None,
785                                user_agent: None,
786                            },
787                        )
788                        .await
789                        .into_mas("failed to write deviceless compat sessions")?;
790
791                    deviceless_session_id
792                };
793
794                let token_id =
795                    Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
796
797                write_buffer
798                    .write(
799                        &mut mas,
800                        MasNewCompatAccessToken {
801                            token_id,
802                            session_id,
803                            access_token: token,
804                            created_at,
805                            expires_at: valid_until_ms.map(DateTime::from),
806                        },
807                    )
808                    .await
809                    .into_mas("writing compat access tokens")?;
810
811                progress_counter.increment_migrated();
812            }
813            write_buffer
814                .finish(&mut mas)
815                .await
816                .into_mas("writing compat access tokens")?;
817            deviceless_session_write_buffer
818                .finish(&mut mas)
819                .await
820                .into_mas("writing deviceless compat sessions")?;
821
822            Ok((mas, state))
823        }
824        .instrument(tracing::info_span!("ingest_task")),
825    );
826
827    // In case this has an error, we still want to join the task, so we look at the
828    // error later
829    let res = synapse
830        .read_unrefreshable_access_tokens()
831        .map_err(|e| e.into_synapse("reading tokens"))
832        .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
833        .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
834        .await;
835
836    let (mas, state) = task.await.into_join("token write task")??;
837
838    res?;
839
840    info!(
841        "{} non-refreshable access tokens migrated ({} skipped) in {:.1}s",
842        progress_counter_.migrated(),
843        progress_counter_.skipped(),
844        Instant::now().duration_since(start).as_secs_f64()
845    );
846
847    Ok((mas, state))
848}
849
850/// Migrates (access token, refresh token) pairs.
851/// Does not migrate non-refreshable access tokens.
852#[tracing::instrument(skip_all, level = Level::INFO)]
853async fn migrate_refreshable_token_pairs(
854    synapse: &mut SynapseReader<'_>,
855    mut mas: MasWriter,
856    clock: &dyn Clock,
857    rng: &mut impl RngCore,
858    mut state: MigrationState,
859    progress_counter: ProgressCounter,
860) -> Result<(MasWriter, MigrationState), Error> {
861    let start = Instant::now();
862    let progress_counter_ = progress_counter.clone();
863
864    let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseRefreshableTokenPair>(100 * 1024);
865
866    // create a new RNG seeded from the passed RNG so that we can move it into the
867    // spawned task
868    let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
869    let now = clock.now();
870    let task = tokio::spawn(
871        async move {
872            let mut access_token_write_buffer = MasWriteBuffer::new(&mas);
873            let mut refresh_token_write_buffer = MasWriteBuffer::new(&mas);
874
875            while let Some(token) = rx.recv().await {
876                let SynapseRefreshableTokenPair {
877                    user_id: synapse_user_id,
878                    device_id,
879                    access_token,
880                    refresh_token,
881                    valid_until_ms,
882                    last_validated,
883                } = token;
884
885                let username = synapse_user_id
886                    .extract_localpart(&state.server_name)
887                    .into_extract_localpart(synapse_user_id.clone())?
888                    .to_owned();
889                let Some(user_infos) = state.users.get(username.as_str()).copied() else {
890                    return Err(Error::MissingUserFromDependentTable {
891                        table: "refresh_tokens".to_owned(),
892                        user: synapse_user_id,
893                    });
894                };
895
896                let Some(mas_user_id) = user_infos.mas_user_id else {
897                    progress_counter.increment_skipped();
898                    continue;
899                };
900
901                if user_infos.flags.is_deactivated()
902                    || user_infos.flags.is_guest()
903                    || user_infos.flags.is_appservice()
904                {
905                    progress_counter.increment_skipped();
906                    continue;
907                }
908
909                // It's not always accurate, but last_validated is *often* the creation time of
910                // the device If we don't have one, then use the current time as a
911                // fallback.
912                let created_at = last_validated.map_or_else(|| now, DateTime::from);
913
914                // Use the existing device_id if this is the second token for a device
915                let session_id = *state
916                    .devices_to_compat_sessions
917                    .entry((mas_user_id, CompactString::new(&device_id)))
918                    .or_insert_with(|| {
919                        Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
920                    });
921
922                let access_token_id =
923                    Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
924                let refresh_token_id =
925                    Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
926
927                access_token_write_buffer
928                    .write(
929                        &mut mas,
930                        MasNewCompatAccessToken {
931                            token_id: access_token_id,
932                            session_id,
933                            access_token,
934                            created_at,
935                            expires_at: valid_until_ms.map(DateTime::from),
936                        },
937                    )
938                    .await
939                    .into_mas("writing compat access tokens")?;
940                refresh_token_write_buffer
941                    .write(
942                        &mut mas,
943                        MasNewCompatRefreshToken {
944                            refresh_token_id,
945                            session_id,
946                            access_token_id,
947                            refresh_token,
948                            created_at,
949                        },
950                    )
951                    .await
952                    .into_mas("writing compat refresh tokens")?;
953
954                progress_counter.increment_migrated();
955            }
956
957            access_token_write_buffer
958                .finish(&mut mas)
959                .await
960                .into_mas("writing compat access tokens")?;
961
962            refresh_token_write_buffer
963                .finish(&mut mas)
964                .await
965                .into_mas("writing compat refresh tokens")?;
966            Ok((mas, state))
967        }
968        .instrument(tracing::info_span!("ingest_task")),
969    );
970
971    // In case this has an error, we still want to join the task, so we look at the
972    // error later
973    let res = synapse
974        .read_refreshable_token_pairs()
975        .map_err(|e| e.into_synapse("reading refresh token pairs"))
976        .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
977        .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
978        .await;
979
980    let (mas, state) = task.await.into_join("refresh token write task")??;
981
982    res?;
983
984    info!(
985        "{} refreshable token pairs migrated ({} skipped) in {:.1}s",
986        progress_counter_.migrated(),
987        progress_counter_.skipped(),
988        Instant::now().duration_since(start).as_secs_f64()
989    );
990
991    Ok((mas, state))
992}
993
994fn transform_user(
995    user: &SynapseUser,
996    server_name: &str,
997    rng: &mut impl RngCore,
998) -> Result<(MasNewUser, Option<MasNewUserPassword>), Error> {
999    let username = user
1000        .name
1001        .extract_localpart(server_name)
1002        .into_extract_localpart(user.name.clone())?
1003        .to_owned();
1004
1005    let user_id = Uuid::from(Ulid::from_datetime_with_source(
1006        DateTime::<Utc>::from(user.creation_ts).into(),
1007        rng,
1008    ))
1009    .try_into()
1010    .expect("ULID generation lead to a nil UUID, this is a bug!");
1011
1012    let new_user = MasNewUser {
1013        user_id,
1014        username,
1015        created_at: user.creation_ts.into(),
1016        locked_at: user.locked.then_some(user.creation_ts.into()),
1017        deactivated_at: bool::from(user.deactivated).then_some(user.creation_ts.into()),
1018        can_request_admin: bool::from(user.admin),
1019        is_guest: bool::from(user.is_guest),
1020    };
1021
1022    let mas_password = user
1023        .password_hash
1024        .clone()
1025        .map(|password_hash| MasNewUserPassword {
1026            user_password_id: Uuid::from(Ulid::from_datetime_with_source(
1027                DateTime::<Utc>::from(user.creation_ts).into(),
1028                rng,
1029            )),
1030            user_id: new_user.user_id,
1031            hashed_password: password_hash,
1032            created_at: new_user.created_at,
1033        });
1034
1035    Ok((new_user, mas_password))
1036}