mas_storage_pg/user/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the user-related
8//! repositories
9
10use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query, extension::postgres::PgExpr as _};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21    DatabaseError,
22    filter::{Filter, StatementExt},
23    iden::Users,
24    pagination::QueryBuilderExt,
25    tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40    email::PgUserEmailRepository, password::PgUserPasswordRepository,
41    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43    terms::PgUserTermsRepository,
44};
45
46/// An implementation of [`UserRepository`] for a PostgreSQL connection
47pub struct PgUserRepository<'c> {
48    conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
53    pub fn new(conn: &'c mut PgConnection) -> Self {
54        Self { conn }
55    }
56}
57
58mod priv_ {
59    // The enum_def macro generates a public enum, which we don't want, because it
60    // triggers the missing docs warning
61    #![allow(missing_docs)]
62
63    use chrono::{DateTime, Utc};
64    use mas_storage::pagination::Node;
65    use sea_query::enum_def;
66    use ulid::Ulid;
67    use uuid::Uuid;
68
69    #[derive(Debug, Clone, sqlx::FromRow)]
70    #[enum_def]
71    pub(super) struct UserLookup {
72        pub(super) user_id: Uuid,
73        pub(super) username: String,
74        pub(super) created_at: DateTime<Utc>,
75        pub(super) locked_at: Option<DateTime<Utc>>,
76        pub(super) deactivated_at: Option<DateTime<Utc>>,
77        pub(super) can_request_admin: bool,
78        pub(super) is_guest: bool,
79    }
80
81    impl Node<Ulid> for UserLookup {
82        fn cursor(&self) -> Ulid {
83            self.user_id.into()
84        }
85    }
86}
87
88use priv_::{UserLookup, UserLookupIden};
89
90impl From<UserLookup> for User {
91    fn from(value: UserLookup) -> Self {
92        let id = value.user_id.into();
93        Self {
94            id,
95            username: value.username,
96            sub: id.to_string(),
97            created_at: value.created_at,
98            locked_at: value.locked_at,
99            deactivated_at: value.deactivated_at,
100            can_request_admin: value.can_request_admin,
101            is_guest: value.is_guest,
102        }
103    }
104}
105
106impl Filter for UserFilter<'_> {
107    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
108        sea_query::Condition::all()
109            .add_option(self.state().map(|state| {
110                match state {
111                    mas_storage::user::UserState::Deactivated => {
112                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
113                    }
114                    mas_storage::user::UserState::Locked => {
115                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
116                    }
117                    mas_storage::user::UserState::Active => {
118                        Expr::col((Users::Table, Users::LockedAt))
119                            .is_null()
120                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
121                    }
122                }
123            }))
124            .add_option(self.can_request_admin().map(|can_request_admin| {
125                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
126            }))
127            .add_option(
128                self.is_guest()
129                    .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
130            )
131            .add_option(self.search().map(|search| {
132                Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
133            }))
134    }
135}
136
137#[async_trait]
138impl UserRepository for PgUserRepository<'_> {
139    type Error = DatabaseError;
140
141    #[tracing::instrument(
142        name = "db.user.lookup",
143        skip_all,
144        fields(
145            db.query.text,
146            user.id = %id,
147        ),
148        err,
149    )]
150    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
151        let res = sqlx::query_as!(
152            UserLookup,
153            r#"
154                SELECT user_id
155                     , username
156                     , created_at
157                     , locked_at
158                     , deactivated_at
159                     , can_request_admin
160                     , is_guest
161                FROM users
162                WHERE user_id = $1
163            "#,
164            Uuid::from(id),
165        )
166        .traced()
167        .fetch_optional(&mut *self.conn)
168        .await?;
169
170        let Some(res) = res else { return Ok(None) };
171
172        Ok(Some(res.into()))
173    }
174
175    #[tracing::instrument(
176        name = "db.user.find_by_username",
177        skip_all,
178        fields(
179            db.query.text,
180            user.username = username,
181        ),
182        err,
183    )]
184    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
185        // We may have multiple users with the same username, but with a different
186        // casing. In this case, we want to return the one which matches the exact
187        // casing
188        let res = sqlx::query_as!(
189            UserLookup,
190            r#"
191                SELECT user_id
192                     , username
193                     , created_at
194                     , locked_at
195                     , deactivated_at
196                     , can_request_admin
197                     , is_guest
198                FROM users
199                WHERE LOWER(username) = LOWER($1)
200            "#,
201            username,
202        )
203        .traced()
204        .fetch_all(&mut *self.conn)
205        .await?;
206
207        match &res[..] {
208            // Happy path: there is only one user matching the username…
209            [user] => Ok(Some(user.clone().into())),
210            // …or none.
211            [] => Ok(None),
212            list => {
213                // If there are multiple users with the same username, we want to
214                // return the one which matches the exact casing
215                if let Some(user) = list.iter().find(|user| user.username == username) {
216                    Ok(Some(user.clone().into()))
217                } else {
218                    // If none match exactly, we prefer to return nothing
219                    Ok(None)
220                }
221            }
222        }
223    }
224
225    #[tracing::instrument(
226        name = "db.user.add",
227        skip_all,
228        fields(
229            db.query.text,
230            user.username = username,
231            user.id,
232        ),
233        err,
234    )]
235    async fn add(
236        &mut self,
237        rng: &mut (dyn RngCore + Send),
238        clock: &dyn Clock,
239        username: String,
240    ) -> Result<User, Self::Error> {
241        let created_at = clock.now();
242        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
243        tracing::Span::current().record("user.id", tracing::field::display(id));
244
245        let res = sqlx::query!(
246            r#"
247                INSERT INTO users (user_id, username, created_at)
248                VALUES ($1, $2, $3)
249                ON CONFLICT (username) DO NOTHING
250            "#,
251            Uuid::from(id),
252            username,
253            created_at,
254        )
255        .traced()
256        .execute(&mut *self.conn)
257        .await?;
258
259        // If the user already exists, want to return an error but not poison the
260        // transaction
261        DatabaseError::ensure_affected_rows(&res, 1)?;
262
263        Ok(User {
264            id,
265            username,
266            sub: id.to_string(),
267            created_at,
268            locked_at: None,
269            deactivated_at: None,
270            can_request_admin: false,
271            is_guest: false,
272        })
273    }
274
275    #[tracing::instrument(
276        name = "db.user.exists",
277        skip_all,
278        fields(
279            db.query.text,
280            user.username = username,
281        ),
282        err,
283    )]
284    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
285        let exists = sqlx::query_scalar!(
286            r#"
287                SELECT EXISTS(
288                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
289                ) AS "exists!"
290            "#,
291            username
292        )
293        .traced()
294        .fetch_one(&mut *self.conn)
295        .await?;
296
297        Ok(exists)
298    }
299
300    #[tracing::instrument(
301        name = "db.user.lock",
302        skip_all,
303        fields(
304            db.query.text,
305            %user.id,
306        ),
307        err,
308    )]
309    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
310        if user.locked_at.is_some() {
311            return Ok(user);
312        }
313
314        let locked_at = clock.now();
315        let res = sqlx::query!(
316            r#"
317                UPDATE users
318                SET locked_at = $1
319                WHERE user_id = $2
320            "#,
321            locked_at,
322            Uuid::from(user.id),
323        )
324        .traced()
325        .execute(&mut *self.conn)
326        .await?;
327
328        DatabaseError::ensure_affected_rows(&res, 1)?;
329
330        user.locked_at = Some(locked_at);
331
332        Ok(user)
333    }
334
335    #[tracing::instrument(
336        name = "db.user.unlock",
337        skip_all,
338        fields(
339            db.query.text,
340            %user.id,
341        ),
342        err,
343    )]
344    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
345        if user.locked_at.is_none() {
346            return Ok(user);
347        }
348
349        let res = sqlx::query!(
350            r#"
351                UPDATE users
352                SET locked_at = NULL
353                WHERE user_id = $1
354            "#,
355            Uuid::from(user.id),
356        )
357        .traced()
358        .execute(&mut *self.conn)
359        .await?;
360
361        DatabaseError::ensure_affected_rows(&res, 1)?;
362
363        user.locked_at = None;
364
365        Ok(user)
366    }
367
368    #[tracing::instrument(
369        name = "db.user.deactivate",
370        skip_all,
371        fields(
372            db.query.text,
373            %user.id,
374        ),
375        err,
376    )]
377    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
378        if user.deactivated_at.is_some() {
379            return Ok(user);
380        }
381
382        let deactivated_at = clock.now();
383        let res = sqlx::query!(
384            r#"
385                UPDATE users
386                SET deactivated_at = $2
387                WHERE user_id = $1
388                  AND deactivated_at IS NULL
389            "#,
390            Uuid::from(user.id),
391            deactivated_at,
392        )
393        .traced()
394        .execute(&mut *self.conn)
395        .await?;
396
397        DatabaseError::ensure_affected_rows(&res, 1)?;
398
399        user.deactivated_at = Some(deactivated_at);
400
401        Ok(user)
402    }
403
404    #[tracing::instrument(
405        name = "db.user.reactivate",
406        skip_all,
407        fields(
408            db.query.text,
409            %user.id,
410        ),
411        err,
412    )]
413    async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
414        if user.deactivated_at.is_none() {
415            return Ok(user);
416        }
417
418        let res = sqlx::query!(
419            r#"
420                UPDATE users
421                SET deactivated_at = NULL
422                WHERE user_id = $1
423            "#,
424            Uuid::from(user.id),
425        )
426        .traced()
427        .execute(&mut *self.conn)
428        .await?;
429
430        DatabaseError::ensure_affected_rows(&res, 1)?;
431
432        user.deactivated_at = None;
433
434        Ok(user)
435    }
436
437    #[tracing::instrument(
438        name = "db.user.set_can_request_admin",
439        skip_all,
440        fields(
441            db.query.text,
442            %user.id,
443            user.can_request_admin = can_request_admin,
444        ),
445        err,
446    )]
447    async fn set_can_request_admin(
448        &mut self,
449        mut user: User,
450        can_request_admin: bool,
451    ) -> Result<User, Self::Error> {
452        let res = sqlx::query!(
453            r#"
454                UPDATE users
455                SET can_request_admin = $2
456                WHERE user_id = $1
457            "#,
458            Uuid::from(user.id),
459            can_request_admin,
460        )
461        .traced()
462        .execute(&mut *self.conn)
463        .await?;
464
465        DatabaseError::ensure_affected_rows(&res, 1)?;
466
467        user.can_request_admin = can_request_admin;
468
469        Ok(user)
470    }
471
472    #[tracing::instrument(
473        name = "db.user.list",
474        skip_all,
475        fields(
476            db.query.text,
477        ),
478        err,
479    )]
480    async fn list(
481        &mut self,
482        filter: UserFilter<'_>,
483        pagination: mas_storage::Pagination,
484    ) -> Result<mas_storage::Page<User>, Self::Error> {
485        let (sql, arguments) = Query::select()
486            .expr_as(
487                Expr::col((Users::Table, Users::UserId)),
488                UserLookupIden::UserId,
489            )
490            .expr_as(
491                Expr::col((Users::Table, Users::Username)),
492                UserLookupIden::Username,
493            )
494            .expr_as(
495                Expr::col((Users::Table, Users::CreatedAt)),
496                UserLookupIden::CreatedAt,
497            )
498            .expr_as(
499                Expr::col((Users::Table, Users::LockedAt)),
500                UserLookupIden::LockedAt,
501            )
502            .expr_as(
503                Expr::col((Users::Table, Users::DeactivatedAt)),
504                UserLookupIden::DeactivatedAt,
505            )
506            .expr_as(
507                Expr::col((Users::Table, Users::CanRequestAdmin)),
508                UserLookupIden::CanRequestAdmin,
509            )
510            .expr_as(
511                Expr::col((Users::Table, Users::IsGuest)),
512                UserLookupIden::IsGuest,
513            )
514            .from(Users::Table)
515            .apply_filter(filter)
516            .generate_pagination((Users::Table, Users::UserId), pagination)
517            .build_sqlx(PostgresQueryBuilder);
518
519        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
520            .traced()
521            .fetch_all(&mut *self.conn)
522            .await?;
523
524        let page = pagination.process(edges).map(User::from);
525
526        Ok(page)
527    }
528
529    #[tracing::instrument(
530        name = "db.user.count",
531        skip_all,
532        fields(
533            db.query.text,
534        ),
535        err,
536    )]
537    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
538        let (sql, arguments) = Query::select()
539            .expr(Expr::col((Users::Table, Users::UserId)).count())
540            .from(Users::Table)
541            .apply_filter(filter)
542            .build_sqlx(PostgresQueryBuilder);
543
544        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
545            .traced()
546            .fetch_one(&mut *self.conn)
547            .await?;
548
549        count
550            .try_into()
551            .map_err(DatabaseError::to_invalid_operation)
552    }
553
554    #[tracing::instrument(
555        name = "db.user.acquire_lock_for_sync",
556        skip_all,
557        fields(
558            db.query.text,
559            user.id = %user.id,
560        ),
561        err,
562    )]
563    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
564        // XXX: this lock isn't stictly scoped to users, but as we don't use many
565        // postgres advisory locks, it's fine for now. Later on, we could use row-level
566        // locks to make sure we don't get into trouble
567
568        // Convert the user ID to a u128 and grab the lower 64 bits
569        // As this includes 64bit of the random part of the ULID, it should be random
570        // enough to not collide
571        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
572
573        // Use a PG advisory lock, which will be released when the transaction is
574        // committed or rolled back
575        sqlx::query!(
576            r#"
577                SELECT pg_advisory_xact_lock($1)
578            "#,
579            lock_id,
580        )
581        .traced()
582        .execute(&mut *self.conn)
583        .await?;
584
585        Ok(())
586    }
587}