mas_templates/
functions.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7// This is needed to make the Environment::add* functions work
8#![allow(clippy::needless_pass_by_value)]
9
10//! Additional functions, tests and filters used in templates
11
12use std::{
13    collections::HashMap,
14    fmt::Formatter,
15    str::FromStr,
16    sync::{Arc, atomic::AtomicUsize},
17};
18
19use camino::Utf8Path;
20use mas_i18n::{Argument, ArgumentList, DataLocale, Translator, sprintf::FormattedMessagePart};
21use mas_router::UrlBuilder;
22use mas_spa::ViteManifest;
23use minijinja::{
24    Error, ErrorKind, State, Value, escape_formatter,
25    machinery::make_string_output,
26    value::{Kwargs, Object, ViaDeserialize, from_args},
27};
28use url::Url;
29
30pub fn register(
31    env: &mut minijinja::Environment,
32    url_builder: UrlBuilder,
33    vite_manifest: ViteManifest,
34    translator: Arc<Translator>,
35) {
36    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
37
38    minijinja_contrib::add_to_environment(env);
39    env.add_test("empty", self::tester_empty);
40    env.add_filter("to_params", filter_to_params);
41    env.add_filter("simplify_url", filter_simplify_url);
42    env.add_filter("add_slashes", filter_add_slashes);
43    env.add_filter("parse_user_agent", filter_parse_user_agent);
44    env.add_function("add_params_to_url", function_add_params_to_url);
45    env.add_function("counter", || Ok(Value::from_object(Counter::default())));
46    env.add_global(
47        "include_asset",
48        Value::from_object(IncludeAsset {
49            url_builder: url_builder.clone(),
50            vite_manifest,
51        }),
52    );
53    env.add_global(
54        "translator",
55        Value::from_object(TranslatorFunc { translator }),
56    );
57    env.add_filter("prefix_url", move |url: &str| -> String {
58        if !url.starts_with('/') {
59            // Let's assume it's not an internal URL and return it as-is
60            return url.to_owned();
61        }
62
63        let Some(prefix) = url_builder.prefix() else {
64            // If there is no prefix to add, return the URL as-is
65            return url.to_owned();
66        };
67
68        format!("{prefix}{url}")
69    });
70}
71
72fn tester_empty(seq: Value) -> bool {
73    seq.len() == Some(0)
74}
75
/// Filter which backslash-escapes every backslash, double quote and single
/// quote in `value`.
fn filter_add_slashes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        // Each special character gets exactly one leading backslash; inserted
        // backslashes are never escaped again.
        if matches!(ch, '\\' | '"' | '\'') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
82
83fn filter_to_params(params: &Value, kwargs: Kwargs) -> Result<String, Error> {
84    let params = serde_urlencoded::to_string(params).map_err(|e| {
85        Error::new(
86            ErrorKind::InvalidOperation,
87            "Could not serialize parameters",
88        )
89        .with_source(e)
90    })?;
91
92    let prefix = kwargs.get("prefix").unwrap_or("");
93    kwargs.assert_all_used()?;
94
95    if params.is_empty() {
96        Ok(String::new())
97    } else {
98        Ok(format!("{prefix}{params}"))
99    }
100}
101
102/// Filter which simplifies a URL to its domain name for HTTP(S) URLs
103fn filter_simplify_url(url: &str, kwargs: Kwargs) -> Result<String, minijinja::Error> {
104    // Do nothing if the URL is not valid
105    let Ok(mut url) = Url::from_str(url) else {
106        return Ok(url.to_owned());
107    };
108
109    // Always at least remove the query parameters and fragment
110    url.set_query(None);
111    url.set_fragment(None);
112
113    // Do nothing else for non-HTTPS URLs
114    if url.scheme() != "https" {
115        return Ok(url.to_string());
116    }
117
118    let keep_path = kwargs.get::<Option<bool>>("keep_path")?.unwrap_or_default();
119    kwargs.assert_all_used()?;
120
121    // Only return the domain name
122    let Some(domain) = url.domain() else {
123        return Ok(url.to_string());
124    };
125
126    if keep_path {
127        Ok(format!(
128            "{domain}{path}",
129            domain = domain,
130            path = url.path(),
131        ))
132    } else {
133        Ok(domain.to_owned())
134    }
135}
136
137/// Filter which parses a user-agent string
138fn filter_parse_user_agent(user_agent: String) -> Value {
139    let user_agent = mas_data_model::UserAgent::parse(user_agent);
140    Value::from_serialize(user_agent)
141}
142
/// Which part of a URL `add_params_to_url` writes its parameters into.
enum ParamsWhere {
    /// The query string (after `?`)
    Query,
    /// The fragment (after `#`)
    Fragment,
}
147
148fn function_add_params_to_url(
149    uri: ViaDeserialize<Url>,
150    mode: &str,
151    params: ViaDeserialize<HashMap<String, Value>>,
152) -> Result<String, Error> {
153    use ParamsWhere::{Fragment, Query};
154
155    let mode = match mode {
156        "fragment" => Fragment,
157        "query" => Query,
158        _ => {
159            return Err(Error::new(
160                ErrorKind::InvalidOperation,
161                "Invalid `mode` parameter",
162            ));
163        }
164    };
165
166    // First, get the `uri`, `mode` and `params` parameters
167    // Get the relevant part of the URI and parse for existing parameters
168    let existing = match mode {
169        Fragment => uri.fragment(),
170        Query => uri.query(),
171    };
172    let existing: HashMap<String, Value> = existing
173        .map(serde_urlencoded::from_str)
174        .transpose()
175        .map_err(|e| {
176            Error::new(
177                ErrorKind::InvalidOperation,
178                "Could not parse existing `uri` parameters",
179            )
180            .with_source(e)
181        })?
182        .unwrap_or_default();
183
184    // Merge the exising and the additional parameters together
185    let params: HashMap<&String, &Value> = params.iter().chain(existing.iter()).collect();
186
187    // Transform them back to urlencoded
188    let params = serde_urlencoded::to_string(params).map_err(|e| {
189        Error::new(
190            ErrorKind::InvalidOperation,
191            "Could not serialize back parameters",
192        )
193        .with_source(e)
194    })?;
195
196    let uri = {
197        let mut uri = uri;
198        match mode {
199            Fragment => uri.set_fragment(Some(&params)),
200            Query => uri.set_query(Some(&params)),
201        }
202        uri
203    };
204
205    Ok(uri.to_string())
206}
207
208struct TranslatorFunc {
209    translator: Arc<Translator>,
210}
211
212impl std::fmt::Debug for TranslatorFunc {
213    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214        f.debug_struct("TranslatorFunc")
215            .field("translator", &"..")
216            .finish()
217    }
218}
219
220impl std::fmt::Display for TranslatorFunc {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        f.write_str("translator")
223    }
224}
225
226impl Object for TranslatorFunc {
227    fn call(self: &Arc<Self>, _state: &State, args: &[Value]) -> Result<Value, Error> {
228        let (lang,): (&str,) = from_args(args)?;
229
230        let lang: DataLocale = lang.parse().map_err(|e| {
231            Error::new(ErrorKind::InvalidOperation, "Invalid language").with_source(e)
232        })?;
233
234        Ok(Value::from_object(TranslateFunc {
235            lang,
236            translator: Arc::clone(&self.translator),
237        }))
238    }
239}
240
241struct TranslateFunc {
242    translator: Arc<Translator>,
243    lang: DataLocale,
244}
245
246impl std::fmt::Debug for TranslateFunc {
247    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
248        f.debug_struct("Translate")
249            .field("translator", &"..")
250            .field("lang", &self.lang)
251            .finish()
252    }
253}
254
255impl std::fmt::Display for TranslateFunc {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        f.write_str("translate")
258    }
259}
260
261impl Object for TranslateFunc {
262    fn call(self: &Arc<Self>, state: &State, args: &[Value]) -> Result<Value, Error> {
263        let (key, kwargs): (&str, Kwargs) = from_args(args)?;
264
265        let (message, _locale) = if let Some(count) = kwargs.get("count")? {
266            self.translator
267                .plural_with_fallback(self.lang.clone(), key, count)
268                .ok_or(Error::new(
269                    ErrorKind::InvalidOperation,
270                    "Missing translation",
271                ))?
272        } else {
273            self.translator
274                .message_with_fallback(self.lang.clone(), key)
275                .ok_or(Error::new(
276                    ErrorKind::InvalidOperation,
277                    "Missing translation",
278                ))?
279        };
280
281        let res: Result<ArgumentList, Error> = kwargs
282            .args()
283            .map(|name| {
284                let value: Value = kwargs.get(name)?;
285                let value = serde_json::to_value(value).map_err(|e| {
286                    Error::new(ErrorKind::InvalidOperation, "Could not serialize argument")
287                        .with_source(e)
288                })?;
289
290                Ok::<_, Error>(Argument::named(name.to_owned(), value))
291            })
292            .collect();
293        let list = res?;
294
295        let formatted = message.format_(&list).map_err(|e| {
296            Error::new(ErrorKind::InvalidOperation, "Could not format message").with_source(e)
297        })?;
298
299        let mut buf = String::with_capacity(formatted.len());
300        let mut output = make_string_output(&mut buf);
301        for part in formatted.parts() {
302            match part {
303                FormattedMessagePart::Text(text) => {
304                    // Literal text, just write it
305                    output.write_str(text)?;
306                }
307                FormattedMessagePart::Placeholder(placeholder) => {
308                    // Placeholder, escape it
309                    escape_formatter(&mut output, state, &placeholder.as_str().into())?;
310                }
311            }
312        }
313
314        Ok(Value::from_safe_string(buf))
315    }
316
317    fn call_method(
318        self: &Arc<Self>,
319        _state: &State,
320        name: &str,
321        args: &[Value],
322    ) -> Result<Value, Error> {
323        match name {
324            "relative_date" => {
325                let (date,): (String,) = from_args(args)?;
326                let date: chrono::DateTime<chrono::Utc> = date.parse().map_err(|e| {
327                    Error::new(
328                        ErrorKind::InvalidOperation,
329                        "Invalid date while calling function `relative_date`",
330                    )
331                    .with_source(e)
332                })?;
333
334                // TODO: grab the clock somewhere
335                #[allow(clippy::disallowed_methods)]
336                let now = chrono::Utc::now();
337
338                let diff = (date - now).num_days();
339
340                Ok(Value::from(
341                    self.translator
342                        .relative_date(&self.lang, diff)
343                        .map_err(|_e| {
344                            Error::new(
345                                ErrorKind::InvalidOperation,
346                                "Failed to format relative date",
347                            )
348                        })?,
349                ))
350            }
351
352            "short_time" => {
353                let (date,): (String,) = from_args(args)?;
354                let date: chrono::DateTime<chrono::Utc> = date.parse().map_err(|e| {
355                    Error::new(
356                        ErrorKind::InvalidOperation,
357                        "Invalid date while calling function `time`",
358                    )
359                    .with_source(e)
360                })?;
361
362                // TODO: we should use the user's timezone here
363                let time = date.time();
364
365                Ok(Value::from(
366                    self.translator
367                        .short_time(&self.lang, &TimeAdapter(time))
368                        .map_err(|_e| {
369                            Error::new(ErrorKind::InvalidOperation, "Failed to format time")
370                        })?,
371                ))
372            }
373
374            _ => Err(Error::new(
375                ErrorKind::InvalidOperation,
376                "Invalid method on include_asset",
377            )),
378        }
379    }
380}
381
382/// An adapter to make a [`Timelike`] implement [`IsoTimeInput`]
383///
384/// [`Timelike`]: chrono::Timelike
385/// [`IsoTimeInput`]: mas_i18n::icu_datetime::input::IsoTimeInput
386struct TimeAdapter<T>(T);
387
388impl<T: chrono::Timelike> mas_i18n::icu_datetime::input::IsoTimeInput for TimeAdapter<T> {
389    fn hour(&self) -> Option<mas_i18n::icu_calendar::types::IsoHour> {
390        let hour: usize = chrono::Timelike::hour(&self.0).try_into().ok()?;
391        hour.try_into().ok()
392    }
393
394    fn minute(&self) -> Option<mas_i18n::icu_calendar::types::IsoMinute> {
395        let minute: usize = chrono::Timelike::minute(&self.0).try_into().ok()?;
396        minute.try_into().ok()
397    }
398
399    fn second(&self) -> Option<mas_i18n::icu_calendar::types::IsoSecond> {
400        let second: usize = chrono::Timelike::second(&self.0).try_into().ok()?;
401        second.try_into().ok()
402    }
403
404    fn nanosecond(&self) -> Option<mas_i18n::icu_calendar::types::NanoSecond> {
405        let nanosecond: usize = chrono::Timelike::nanosecond(&self.0).try_into().ok()?;
406        nanosecond.try_into().ok()
407    }
408}
409
410struct IncludeAsset {
411    url_builder: UrlBuilder,
412    vite_manifest: ViteManifest,
413}
414
415impl std::fmt::Debug for IncludeAsset {
416    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
417        f.debug_struct("IncludeAsset")
418            .field("url_builder", &self.url_builder.assets_base())
419            .field("vite_manifest", &"..")
420            .finish()
421    }
422}
423
424impl std::fmt::Display for IncludeAsset {
425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        f.write_str("include_asset")
427    }
428}
429
430impl Object for IncludeAsset {
431    fn call(self: &Arc<Self>, _state: &State, args: &[Value]) -> Result<Value, Error> {
432        let (path,): (&str,) = from_args(args)?;
433
434        let path: &Utf8Path = path.into();
435
436        let (main, imported) = self.vite_manifest.find_assets(path).map_err(|_e| {
437            Error::new(
438                ErrorKind::InvalidOperation,
439                "Invalid assets manifest while calling function `include_asset`",
440            )
441        })?;
442
443        let assets = std::iter::once(main)
444            .chain(imported.iter().filter(|a| a.is_stylesheet()).copied())
445            .filter_map(|asset| asset.include_tag(self.url_builder.assets_base().into()));
446
447        let preloads = imported
448            .iter()
449            .filter(|a| a.is_script())
450            .map(|asset| asset.preload_tag(self.url_builder.assets_base().into()));
451
452        let tags: Vec<String> = preloads.chain(assets).collect();
453
454        Ok(Value::from_safe_string(tags.join("\n")))
455    }
456}
457
/// Value returned by the `counter()` template function. It holds a count that
/// starts at zero.
#[derive(Debug, Default)]
struct Counter {
    count: AtomicUsize,
}

impl std::fmt::Display for Counter {
    /// Displays the current count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.count.load(std::sync::atomic::Ordering::Relaxed);
        write!(f, "{current}")
    }
}
472
473impl Object for Counter {
474    fn call_method(
475        self: &Arc<Self>,
476        _state: &State,
477        name: &str,
478        args: &[Value],
479    ) -> Result<Value, Error> {
480        // None of the methods take any arguments
481        from_args::<()>(args)?;
482
483        match name {
484            "reset" => {
485                self.count.store(0, std::sync::atomic::Ordering::Relaxed);
486                Ok(Value::UNDEFINED)
487            }
488            "next" => {
489                let old = self
490                    .count
491                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
492                Ok(Value::from(old))
493            }
494            "peek" => Ok(Value::from(
495                self.count.load(std::sync::atomic::Ordering::Relaxed),
496            )),
497            _ => Err(Error::new(
498                ErrorKind::InvalidOperation,
499                "Invalid method on counter",
500            )),
501        }
502    }
503}