276 lines
9 KiB
Rust
276 lines
9 KiB
Rust
use std::{env, error::Error, fs};

use async_trait::async_trait;
use chrono::{DateTime, TimeZone, Utc};
use common::{Adaptor, Event, Person, Stats};
use entity::{event, person, stats};
use migration::{Migrator, MigratorTrait};
use sea_orm::{
    strum::Display,
    ActiveModelTrait,
    ActiveValue::{NotSet, Set},
    ColumnTrait, Database, DatabaseConnection, DbErr, EntityTrait, ModelTrait, QueryFilter,
    TransactionError, TransactionTrait, TryIntoModel,
};
use serde_json::json;
use url::Url;
|
|
|
|
mod entity;
|
|
mod migration;
|
|
|
|
pub struct SqlAdaptor {
|
|
db: DatabaseConnection,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Adaptor for SqlAdaptor {
|
|
type Error = SqlAdaptorError;
|
|
|
|
async fn get_stats(&self) -> Result<Stats, Self::Error> {
|
|
let stats_row = get_stats_row(&self.db).await?;
|
|
Ok(Stats {
|
|
event_count: stats_row.event_count.unwrap() as i64,
|
|
person_count: stats_row.person_count.unwrap() as i64,
|
|
})
|
|
}
|
|
|
|
async fn increment_stat_event_count(&self) -> Result<i64, Self::Error> {
|
|
let mut current_stats = get_stats_row(&self.db).await?;
|
|
current_stats.event_count = Set(current_stats.event_count.unwrap() + 1);
|
|
|
|
Ok(current_stats.save(&self.db).await?.event_count.unwrap() as i64)
|
|
}
|
|
|
|
async fn increment_stat_person_count(&self) -> Result<i64, Self::Error> {
|
|
let mut current_stats = get_stats_row(&self.db).await?;
|
|
current_stats.person_count = Set(current_stats.person_count.unwrap() + 1);
|
|
|
|
Ok(current_stats.save(&self.db).await?.person_count.unwrap() as i64)
|
|
}
|
|
|
|
async fn get_people(&self, event_id: String) -> Result<Option<Vec<Person>>, Self::Error> {
|
|
// TODO: optimize into one query
|
|
let event_row = event::Entity::find_by_id(event_id).one(&self.db).await?;
|
|
|
|
Ok(match event_row {
|
|
Some(event) => Some(
|
|
event
|
|
.find_related(person::Entity)
|
|
.all(&self.db)
|
|
.await?
|
|
.into_iter()
|
|
.map(|model| model.into())
|
|
.collect(),
|
|
),
|
|
None => None,
|
|
})
|
|
}
|
|
|
|
async fn upsert_person(
|
|
&self,
|
|
event_id: String,
|
|
person: Person,
|
|
) -> Result<Option<Person>, Self::Error> {
|
|
let data = person::ActiveModel {
|
|
name: Set(person.name.clone()),
|
|
password_hash: Set(person.password_hash),
|
|
created_at: Set(person.created_at.naive_utc()),
|
|
availability: Set(serde_json::to_value(person.availability).unwrap_or(json!([]))),
|
|
event_id: Set(event_id.clone()),
|
|
};
|
|
|
|
// Check if the event exists
|
|
if event::Entity::find_by_id(event_id.clone())
|
|
.one(&self.db)
|
|
.await?
|
|
.is_none()
|
|
{
|
|
return Ok(None);
|
|
}
|
|
|
|
Ok(Some(
|
|
match person::Entity::find_by_id((person.name, event_id))
|
|
.one(&self.db)
|
|
.await?
|
|
{
|
|
Some(_) => data.update(&self.db).await?.try_into_model()?.into(),
|
|
None => data.insert(&self.db).await?.try_into_model()?.into(),
|
|
},
|
|
))
|
|
}
|
|
|
|
async fn get_event(&self, id: String) -> Result<Option<Event>, Self::Error> {
|
|
let existing_event = event::Entity::find_by_id(id).one(&self.db).await?;
|
|
|
|
// Mark as visited
|
|
if let Some(event) = existing_event.clone() {
|
|
let mut event: event::ActiveModel = event.into();
|
|
event.visited_at = Set(Utc::now().naive_utc());
|
|
event.save(&self.db).await?;
|
|
}
|
|
|
|
Ok(existing_event.map(|model| model.into()))
|
|
}
|
|
|
|
async fn create_event(&self, event: Event) -> Result<Event, Self::Error> {
|
|
Ok(event::ActiveModel {
|
|
id: Set(event.id),
|
|
name: Set(event.name),
|
|
created_at: Set(event.created_at.naive_utc()),
|
|
visited_at: Set(event.visited_at.naive_utc()),
|
|
times: Set(serde_json::to_value(event.times).unwrap_or(json!([]))),
|
|
timezone: Set(event.timezone),
|
|
}
|
|
.insert(&self.db)
|
|
.await?
|
|
.try_into_model()?
|
|
.into())
|
|
}
|
|
|
|
async fn delete_events(&self, cutoff: DateTime<Utc>) -> Result<Stats, Self::Error> {
|
|
let (event_count, person_count) = self
|
|
.db
|
|
.transaction::<_, (i64, i64), DbErr>(|t| {
|
|
Box::pin(async move {
|
|
// Get events older than the cutoff date
|
|
let old_events = event::Entity::find()
|
|
.filter(event::Column::VisitedAt.lt(cutoff.naive_utc()))
|
|
.all(t)
|
|
.await?;
|
|
|
|
// Delete people
|
|
let mut people_deleted: i64 = 0;
|
|
// TODO: run concurrently
|
|
for e in old_events.iter() {
|
|
let people_delete_result = person::Entity::delete_many()
|
|
.filter(person::Column::EventId.eq(&e.id))
|
|
.exec(t)
|
|
.await?;
|
|
people_deleted += people_delete_result.rows_affected as i64;
|
|
}
|
|
|
|
// Delete events
|
|
let event_delete_result = event::Entity::delete_many()
|
|
.filter(event::Column::VisitedAt.lt(cutoff.naive_utc()))
|
|
.exec(t)
|
|
.await?;
|
|
|
|
Ok((event_delete_result.rows_affected as i64, people_deleted))
|
|
})
|
|
})
|
|
.await?;
|
|
|
|
Ok(Stats {
|
|
event_count,
|
|
person_count,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Get the current stats as an ActiveModel
|
|
async fn get_stats_row(db: &DatabaseConnection) -> Result<stats::ActiveModel, DbErr> {
|
|
let current_stats = stats::Entity::find().one(db).await?;
|
|
|
|
Ok(match current_stats {
|
|
Some(model) => model.into(),
|
|
None => stats::ActiveModel {
|
|
id: NotSet,
|
|
event_count: Set(0),
|
|
person_count: Set(0),
|
|
},
|
|
})
|
|
}
|
|
|
|
fn get_connection_string() -> String {
|
|
let connection_string =
|
|
env::var("DATABASE_URL").expect("Expected DATABASE_URL environment variable");
|
|
|
|
if let Some(password_file_location) = env::var_os("DATABASE_PASSWORD_FILE") {
|
|
// The password can be left out of the URL, we add it from the specified
|
|
// file (presumably under /run/secrets/)
|
|
let password = fs::read(&password_file_location)
|
|
.unwrap_or_else(|err| {
|
|
panic!("could not read database password from {password_file_location:?}\n\t{err:?}")
|
|
});
|
|
let password = String::from(String::from_utf8_lossy(password.as_slice()));
|
|
let password = password.trim_end();
|
|
let mut url = Url::parse(&connection_string).expect("invalid connection string");
|
|
url.set_password(Some(password))
|
|
.unwrap_or_else(|_| panic!("invalid database URL: {connection_string:?}"));
|
|
url.to_string()
|
|
} else {
|
|
connection_string
|
|
}
|
|
}
|
|
|
|
impl SqlAdaptor {
|
|
pub async fn new() -> Self {
|
|
let connection_string = get_connection_string();
|
|
|
|
// Connect to the database
|
|
let db = Database::connect(&connection_string)
|
|
.await
|
|
.expect("Failed to connect to SQL database");
|
|
println!(
|
|
"{} Connected to database at {}",
|
|
match db {
|
|
DatabaseConnection::SqlxMySqlPoolConnection(_) => "🐬",
|
|
DatabaseConnection::SqlxPostgresPoolConnection(_) => "🐘",
|
|
DatabaseConnection::SqlxSqlitePoolConnection(_) => "🪶",
|
|
DatabaseConnection::Disconnected => panic!("Failed to connect to SQL database"),
|
|
},
|
|
connection_string
|
|
);
|
|
|
|
// Setup tables
|
|
Migrator::up(&db, None)
|
|
.await
|
|
.expect("Failed to set up tables in the database");
|
|
|
|
Self { db }
|
|
}
|
|
}
|
|
|
|
impl From<event::Model> for Event {
|
|
fn from(value: event::Model) -> Self {
|
|
Self {
|
|
id: value.id,
|
|
name: value.name,
|
|
created_at: DateTime::<Utc>::from_utc(value.created_at, Utc),
|
|
visited_at: DateTime::<Utc>::from_utc(value.visited_at, Utc),
|
|
times: serde_json::from_value(value.times).unwrap_or(vec![]),
|
|
timezone: value.timezone,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<person::Model> for Person {
|
|
fn from(value: person::Model) -> Self {
|
|
Self {
|
|
name: value.name,
|
|
password_hash: value.password_hash,
|
|
created_at: DateTime::<Utc>::from_utc(value.created_at, Utc),
|
|
availability: serde_json::from_value(value.availability).unwrap_or(vec![]),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Display, Debug)]
|
|
pub enum SqlAdaptorError {
|
|
DbErr(DbErr),
|
|
TransactionError(TransactionError<DbErr>),
|
|
}
|
|
|
|
impl Error for SqlAdaptorError {}
|
|
|
|
impl From<DbErr> for SqlAdaptorError {
|
|
fn from(value: DbErr) -> Self {
|
|
Self::DbErr(value)
|
|
}
|
|
}
|
|
impl From<TransactionError<DbErr>> for SqlAdaptorError {
|
|
fn from(value: TransactionError<DbErr>) -> Self {
|
|
Self::TransactionError(value)
|
|
}
|
|
}
|