diff --git a/Cargo.lock b/Cargo.lock index df48e85..aa51c1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,6 +93,18 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -143,6 +155,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bindgen" version = "0.72.0" @@ -181,6 +199,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -531,6 +558,18 @@ dependencies = [ "xxhash-rust", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -553,6 +592,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -677,6 +722,7 @@ name = "gamestream-webtransport-proxy" version = "0.1.0" dependencies = [ "anyhow", + "argon2", "directories", "flatbuffers", "getrandom 0.3.3", @@ -690,6 +736,7 @@ dependencies = [ "openssl", "rand 0.9.1", "reqwest", + "rusqlite", "salvo", "serde", "serde-xml-rs", @@ -834,6 +881,18 @@ name = "hashbrown" version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.4", +] [[package]] name = "headers" @@ -1256,6 +1315,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb8270bb4060bd76c6e96f20c52d80620f1d82a3470885694e41e0f81ef6fe7" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -1549,6 +1619,17 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "path-slash" version = "0.2.1" @@ -1989,6 +2070,20 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "rusqlite" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37e34486da88d8e051c7c0e23c3f15fd806ea8546260aa2fec247e97242ec143" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rust-embed" version = "8.6.0" diff --git a/gamestream-webtransport-proxy/Cargo.toml b/gamestream-webtransport-proxy/Cargo.toml index f6cb8fa..a7757fc 100644 --- a/gamestream-webtransport-proxy/Cargo.toml +++ b/gamestream-webtransport-proxy/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" [dependencies] anyhow = "1.0.98" +argon2 = "0.5" directories = "6.0.0" flatbuffers = "25.2.10" getrandom = { version = "0.3.3", features = ["std"] } @@ -17,6 +18,7 @@ libc = "0.2.174" moonlight-common-c-sys = { path = "../moonlight-common-c-sys" } openssl = "0.10.73" rand = "0.9.1" +rusqlite = { version = "0.34", features = ["bundled"] } reqwest = { version = "0.12.20", features = [ "rustls-tls", "native-tls", diff --git a/gamestream-webtransport-proxy/src/auth.rs b/gamestream-webtransport-proxy/src/auth.rs new file mode 100644 index 0000000..22c19d7 --- /dev/null +++ b/gamestream-webtransport-proxy/src/auth.rs @@ -0,0 +1,326 @@ +use std::sync::Arc; + +use salvo::prelude::*; +use serde::{Deserialize, Serialize}; +use tracing::error; + +use crate::common::{AppError, AppResult}; +use crate::db::{AppPermission, Db, User}; + +const SESSION_MAX_AGE_SECONDS: i64 = 7 * 24 * 3600; // 7 days + +// Key used to store the authenticated user in the Salvo Depot +const USER_DEPOT_KEY: &str = "authenticated_user"; + +pub fn get_user_from_depot(depot: &Depot) -> Option<&User> { + depot.get::(USER_DEPOT_KEY).ok() +} + +// -- Middleware -- + +pub struct SessionAuthMiddleware { + pub db: Arc, +} + +#[handler] +impl SessionAuthMiddleware { + async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { + let token = req + .headers() + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")); + + let token = match token { + Some(t) => t, + None => { + res.status_code(StatusCode::UNAUTHORIZED); + Json(serde_json::json!({"description": "Missing or invalid Authorization header"})).render(res); + ctrl.skip_rest(); + return; + } + }; + + match self.db.validate_session(token) { + Ok(Some(user)) => { + depot.insert(USER_DEPOT_KEY, user); + } + Ok(None) => { + res.status_code(StatusCode::UNAUTHORIZED); + Json(serde_json::json!({"description": "Invalid or expired session"})).render(res); + ctrl.skip_rest(); + return; + } + Err(e) => { + error!("Session validation error: {e}"); + res.status_code(StatusCode::INTERNAL_SERVER_ERROR); + Json(serde_json::json!({"description": "Internal server error"})).render(res); + ctrl.skip_rest(); + return; + } + } + } +} + +pub struct AdminCheckMiddleware; + +#[handler] +impl AdminCheckMiddleware { + async fn handle(&self, _req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { + let user = match get_user_from_depot(depot) { + Some(u) => u, + None => { + res.status_code(StatusCode::UNAUTHORIZED); + Json(serde_json::json!({"description": "Not authenticated"})).render(res); + ctrl.skip_rest(); + return; + } + }; + + if !user.is_admin { + res.status_code(StatusCode::FORBIDDEN); + Json(serde_json::json!({"description": "Admin access required"})).render(res); + ctrl.skip_rest(); + return; + } + } +} + +// -- Request/Response types -- + +#[derive(Deserialize, ToSchema)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Serialize, ToSchema)] +pub struct LoginResponse { + pub token: String, +} + +#[derive(Serialize, ToSchema)] +pub struct MeResponse { + pub username: String, + pub is_admin: bool, + pub permissions: Vec, +} + +#[derive(Deserialize, ToSchema)] +pub struct CreateUserRequest { + pub username: String, + pub password: String, + pub is_admin: bool, +} + +#[derive(Deserialize, ToSchema)] +pub struct UpdateUserRequest { + pub password: Option, + pub is_admin: Option, +} + +#[derive(Deserialize, ToSchema)] +pub struct SetPermissionsRequest { + pub permissions: Vec, +} + +// -- Auth endpoint handlers -- + +#[craft] +impl crate::backend::Backend { + #[craft(handler)] + pub async fn login( + self: Arc, + body: salvo::oapi::extract::JsonBody, + ) -> AppResult> { + let user = match self.db.verify_password(&body.username, &body.password) { + Ok(Some(u)) => u, + Ok(None) => { + return Err(AppError { + status_code: StatusCode::UNAUTHORIZED, + description: "Invalid username or password".to_string(), + }); + } + Err(e) => { + error!("Login error: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Internal server error".to_string(), + }); + } + }; + + let token = match self.db.create_session(&user.id, SESSION_MAX_AGE_SECONDS) { + Ok(t) => t, + Err(e) => { + error!("Session creation error: {e}"); + return Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Internal server error".to_string(), + }); + } + }; + + Ok(Json(LoginResponse { token })) + } + + #[craft(handler)] + pub async fn logout(self: Arc, req: &mut Request) -> AppResult> { + let token = req + .headers() + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + .unwrap_or(""); + + let _ = self.db.delete_session(token); + + Ok(Json(serde_json::json!({"status": "ok"}))) + } + + #[craft(handler)] + pub async fn me(self: Arc, depot: &mut Depot) -> AppResult> { + let user = match get_user_from_depot(depot) { + Some(u) => u.clone(), + None => { + return Err(AppError { + status_code: StatusCode::UNAUTHORIZED, + description: "Not authenticated".to_string(), + }); + } + }; + + let permissions = self.db.get_permissions(&user.id).unwrap_or_default(); + + Ok(Json(MeResponse { + username: user.username, + is_admin: user.is_admin, + permissions, + })) + } + + // -- Admin endpoint handlers -- + + #[craft(handler)] + pub async fn admin_list_users(self: Arc) -> AppResult>> { + match self.db.list_users() { + Ok(users) => Ok(Json(users)), + Err(e) => { + error!("List users error: {e}"); + Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Failed to list users".to_string(), + }) + } + } + } + + #[craft(handler)] + pub async fn admin_create_user( + self: Arc, + body: salvo::oapi::extract::JsonBody, + ) -> AppResult> { + match self + .db + .create_user(&body.username, &body.password, body.is_admin) + { + Ok(user) => Ok(Json(user)), + Err(e) => { + error!("Create user error: {e}"); + Err(AppError { + status_code: StatusCode::BAD_REQUEST, + description: format!("Failed to create user: {e}"), + }) + } + } + } + + #[craft(handler)] + pub async fn admin_update_user( + self: Arc, + req: &mut Request, + body: salvo::oapi::extract::JsonBody, + ) -> AppResult> { + let user_id = req.param::("id").unwrap_or_default(); + + match self + .db + .update_user(&user_id, body.password.as_deref(), body.is_admin) + { + Ok(true) => Ok(Json(serde_json::json!({"status": "ok"}))), + Ok(false) => Err(AppError { + status_code: StatusCode::NOT_FOUND, + description: "User not found".to_string(), + }), + Err(e) => { + error!("Update user error: {e}"); + Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Failed to update user".to_string(), + }) + } + } + } + + #[craft(handler)] + pub async fn admin_delete_user( + self: Arc, + req: &mut Request, + ) -> AppResult> { + let user_id = req.param::("id").unwrap_or_default(); + + match self.db.delete_user(&user_id) { + Ok(true) => Ok(Json(serde_json::json!({"status": "ok"}))), + Ok(false) => Err(AppError { + status_code: StatusCode::NOT_FOUND, + description: "User not found".to_string(), + }), + Err(e) => { + error!("Delete user error: {e}"); + Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Failed to delete user".to_string(), + }) + } + } + } + + #[craft(handler)] + pub async fn admin_get_permissions( + self: Arc, + req: &mut Request, + ) -> AppResult>> { + let user_id = req.param::("id").unwrap_or_default(); + + match self.db.get_permissions(&user_id) { + Ok(perms) => Ok(Json(perms)), + Err(e) => { + error!("Get permissions error: {e}"); + Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Failed to get permissions".to_string(), + }) + } + } + } + + #[craft(handler)] + pub async fn admin_set_permissions( + self: Arc, + req: &mut Request, + body: salvo::oapi::extract::JsonBody, + ) -> AppResult> { + let user_id = req.param::("id").unwrap_or_default(); + + match self.db.set_permissions(&user_id, &body.permissions) { + Ok(()) => Ok(Json(serde_json::json!({"status": "ok"}))), + Err(e) => { + error!("Set permissions error: {e}"); + Err(AppError { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + description: "Failed to set permissions".to_string(), + }) + } + } + } +} diff --git a/gamestream-webtransport-proxy/src/backend.rs b/gamestream-webtransport-proxy/src/backend.rs index fb1f04d..56a4b08 100644 --- a/gamestream-webtransport-proxy/src/backend.rs +++ b/gamestream-webtransport-proxy/src/backend.rs @@ -5,6 +5,7 @@ use salvo::oapi::ToSchema; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; +use crate::db::Db; use crate::state::StateFile; #[derive(Debug, Clone, Deserialize, Serialize)] @@ -89,14 +90,25 @@ pub struct Backend { pub state: StateFile, pub streams: RwLock>, pub port: u16, + pub db: Db, } impl Backend { pub fn new(port: u16) -> Result { + let project_dirs = + directories::ProjectDirs::from("xyz", "ohea", "gamestream-webtransport-proxy") + .ok_or(anyhow::anyhow!("Could not get project dirs"))?; + let data_dir = project_dirs.data_dir(); + std::fs::create_dir_all(data_dir)?; + let db_path = data_dir.join("auth.db"); + + let db = Db::open(&db_path)?; + Ok(Backend { state: StateFile::new()?, streams: RwLock::new(HashMap::new()), port, + db, }) } } diff --git a/gamestream-webtransport-proxy/src/db.rs b/gamestream-webtransport-proxy/src/db.rs new file mode 100644 index 0000000..a04a978 --- /dev/null +++ b/gamestream-webtransport-proxy/src/db.rs @@ -0,0 +1,659 @@ +use std::path::Path; +use std::sync::Mutex; + +use anyhow::{Context, Result}; +use argon2::{ + Argon2, + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng}, +}; +use salvo::oapi::ToSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct User { + pub id: String, + pub username: String, + pub is_admin: bool, + pub created_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct AppPermission { + pub server: String, + pub app_id: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Session { + pub token: String, + pub user_id: String, + pub created_at: String, + pub expires_at: String, +} + +pub struct Db { + conn: Mutex, +} + +impl Db { + pub fn open(path: &Path) -> Result { + let conn = rusqlite::Connection::open(path)?; + let db = Db { + conn: Mutex::new(conn), + }; + db.init()?; + Ok(db) + } + + fn init(&self) -> Result<()> { + let conn = self.conn.lock().unwrap(); + conn.execute_batch("PRAGMA foreign_keys = ON;")?; + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + password TEXT NOT NULL, + is_admin INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS user_app_permissions ( + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + server TEXT NOT NULL, + app_id INTEGER NOT NULL, + PRIMARY KEY (user_id, server, app_id) + ); + + CREATE TABLE IF NOT EXISTS sessions ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + created_at TEXT NOT NULL, + expires_at TEXT NOT NULL + );", + )?; + Ok(()) + } + + pub fn seed_admin_if_needed(&self) -> Result> { + let conn = self.conn.lock().unwrap(); + let count: i64 = conn.query_row("SELECT COUNT(*) FROM users", [], |row| row.get(0))?; + if count > 0 { + return Ok(None); + } + drop(conn); + + let password = generate_random_password(); + let user = self.create_user("admin", &password, true)?; + Ok(Some((user.username, password))) + } + + pub fn create_user(&self, username: &str, password: &str, is_admin: bool) -> Result { + let id = uuid::Uuid::new_v4().to_string(); + let password_hash = hash_password(password)?; + let created_at = now_iso8601(); + + let conn = self.conn.lock().unwrap(); + conn.execute( + "INSERT INTO users (id, username, password, is_admin, created_at) VALUES (?1, ?2, ?3, ?4, ?5)", + rusqlite::params![id, username, password_hash, is_admin as i32, created_at], + ).context("Failed to create user (username may already exist)")?; + + Ok(User { + id, + username: username.to_string(), + is_admin, + created_at, + }) + } + + pub fn verify_password(&self, username: &str, password: &str) -> Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT id, username, password, is_admin, created_at FROM users WHERE username = ?1", + )?; + + let mut rows = stmt.query(rusqlite::params![username])?; + let row = match rows.next()? { + Some(r) => r, + None => return Ok(None), + }; + + let id: String = row.get(0)?; + let uname: String = row.get(1)?; + let stored_hash: String = row.get(2)?; + let is_admin: bool = row.get::<_, i32>(3)? != 0; + let created_at: String = row.get(4)?; + + let parsed_hash = + PasswordHash::new(&stored_hash).map_err(|e| anyhow::anyhow!("Invalid hash: {e}"))?; + if Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .is_err() + { + return Ok(None); + } + + Ok(Some(User { + id, + username: uname, + is_admin, + created_at, + })) + } + + pub fn get_user(&self, user_id: &str) -> Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT id, username, is_admin, created_at FROM users WHERE id = ?1", + )?; + + let mut rows = stmt.query(rusqlite::params![user_id])?; + match rows.next()? { + Some(row) => Ok(Some(User { + id: row.get(0)?, + username: row.get(1)?, + is_admin: row.get::<_, i32>(2)? != 0, + created_at: row.get(3)?, + })), + None => Ok(None), + } + } + + pub fn list_users(&self) -> Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = + conn.prepare("SELECT id, username, is_admin, created_at FROM users ORDER BY username")?; + let users = stmt + .query_map([], |row| { + Ok(User { + id: row.get(0)?, + username: row.get(1)?, + is_admin: row.get::<_, i32>(2)? != 0, + created_at: row.get(3)?, + }) + })? + .collect::, _>>()?; + Ok(users) + } + + pub fn update_user( + &self, + user_id: &str, + new_password: Option<&str>, + new_is_admin: Option, + ) -> Result { + let conn = self.conn.lock().unwrap(); + + if let Some(password) = new_password { + let hash = hash_password(password)?; + conn.execute( + "UPDATE users SET password = ?1 WHERE id = ?2", + rusqlite::params![hash, user_id], + )?; + } + + if let Some(is_admin) = new_is_admin { + conn.execute( + "UPDATE users SET is_admin = ?1 WHERE id = ?2", + rusqlite::params![is_admin as i32, user_id], + )?; + } + + let changed = conn.changes() > 0; + Ok(changed) + } + + pub fn delete_user(&self, user_id: &str) -> Result { + let conn = self.conn.lock().unwrap(); + conn.execute("PRAGMA foreign_keys = ON;", [])?; + let rows = conn.execute("DELETE FROM users WHERE id = ?1", rusqlite::params![user_id])?; + Ok(rows > 0) + } + + // Session management + + pub fn create_session(&self, user_id: &str, max_age_seconds: i64) -> Result { + let token = generate_session_token(); + let created_at = now_iso8601(); + let expires_at = future_iso8601(max_age_seconds); + + let conn = self.conn.lock().unwrap(); + conn.execute( + "INSERT INTO sessions (token, user_id, created_at, expires_at) VALUES (?1, ?2, ?3, ?4)", + rusqlite::params![token, user_id, created_at, expires_at], + )?; + + Ok(token) + } + + pub fn validate_session(&self, token: &str) -> Result> { + let conn = self.conn.lock().unwrap(); + let now = now_iso8601(); + + let mut stmt = conn.prepare( + "SELECT u.id, u.username, u.is_admin, u.created_at + FROM sessions s + JOIN users u ON s.user_id = u.id + WHERE s.token = ?1 AND s.expires_at > ?2", + )?; + + let mut rows = stmt.query(rusqlite::params![token, now])?; + match rows.next()? { + Some(row) => Ok(Some(User { + id: row.get(0)?, + username: row.get(1)?, + is_admin: row.get::<_, i32>(2)? != 0, + created_at: row.get(3)?, + })), + None => Ok(None), + } + } + + pub fn delete_session(&self, token: &str) -> Result { + let conn = self.conn.lock().unwrap(); + let rows = conn.execute( + "DELETE FROM sessions WHERE token = ?1", + rusqlite::params![token], + )?; + Ok(rows > 0) + } + + pub fn cleanup_expired_sessions(&self) -> Result { + let conn = self.conn.lock().unwrap(); + let now = now_iso8601(); + let rows = conn.execute( + "DELETE FROM sessions WHERE expires_at <= ?1", + rusqlite::params![now], + )?; + Ok(rows) + } + + // Permission management + + pub fn set_permissions(&self, user_id: &str, permissions: &[AppPermission]) -> Result<()> { + let conn = self.conn.lock().unwrap(); + conn.execute( + "DELETE FROM user_app_permissions WHERE user_id = ?1", + rusqlite::params![user_id], + )?; + + let mut stmt = conn.prepare( + "INSERT INTO user_app_permissions (user_id, server, app_id) VALUES (?1, ?2, ?3)", + )?; + for perm in permissions { + stmt.execute(rusqlite::params![user_id, perm.server, perm.app_id])?; + } + + Ok(()) + } + + pub fn get_permissions(&self, user_id: &str) -> Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT server, app_id FROM user_app_permissions WHERE user_id = ?1", + )?; + let perms = stmt + .query_map(rusqlite::params![user_id], |row| { + Ok(AppPermission { + server: row.get(0)?, + app_id: row.get(1)?, + }) + })? + .collect::, _>>()?; + Ok(perms) + } + + pub fn check_app_permission( + &self, + user_id: &str, + server: &str, + app_id: i64, + ) -> Result { + // Check if user is admin first + let conn = self.conn.lock().unwrap(); + let is_admin: i32 = conn.query_row( + "SELECT is_admin FROM users WHERE id = ?1", + rusqlite::params![user_id], + |row| row.get(0), + )?; + if is_admin != 0 { + return Ok(true); + } + + let count: i64 = conn.query_row( + "SELECT COUNT(*) FROM user_app_permissions WHERE user_id = ?1 AND server = ?2 AND app_id = ?3", + rusqlite::params![user_id, server, app_id], + |row| row.get(0), + )?; + Ok(count > 0) + } +} + +fn hash_password(password: &str) -> Result { + let salt = SaltString::generate(&mut OsRng); + let hash = Argon2::default() + .hash_password(password.as_bytes(), &salt) + .map_err(|e| anyhow::anyhow!("Failed to hash password: {e}"))?; + Ok(hash.to_string()) +} + +fn generate_session_token() -> String { + let mut bytes = [0u8; 32]; + openssl::rand::rand_bytes(&mut bytes).expect("Failed to generate random bytes"); + hex::encode(bytes) +} + +fn generate_random_password() -> String { + let mut bytes = [0u8; 16]; + openssl::rand::rand_bytes(&mut bytes).expect("Failed to generate random bytes"); + hex::encode(bytes) +} + +fn now_iso8601() -> String { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + .to_string() +} + +fn future_iso8601(seconds_from_now: i64) -> String { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + (now as i64 + seconds_from_now).to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_db() -> Db { + let conn = rusqlite::Connection::open_in_memory().unwrap(); + let db = Db { + conn: Mutex::new(conn), + }; + db.init().unwrap(); + db + } + + #[test] + fn test_create_and_get_user() { + let db = test_db(); + let user = db.create_user("alice", "password123", false).unwrap(); + assert_eq!(user.username, "alice"); + assert!(!user.is_admin); + + let fetched = db.get_user(&user.id).unwrap().unwrap(); + assert_eq!(fetched.username, "alice"); + assert_eq!(fetched.id, user.id); + } + + #[test] + fn test_verify_correct_password() { + let db = test_db(); + db.create_user("bob", "secret", false).unwrap(); + + let result = db.verify_password("bob", "secret").unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().username, "bob"); + } + + #[test] + fn test_verify_wrong_password() { + let db = test_db(); + db.create_user("bob", "secret", false).unwrap(); + + let result = db.verify_password("bob", "wrong").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_verify_nonexistent_user() { + let db = test_db(); + let result = db.verify_password("nobody", "pass").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_duplicate_username_rejected() { + let db = test_db(); + db.create_user("alice", "pass1", false).unwrap(); + let result = db.create_user("alice", "pass2", false); + assert!(result.is_err()); + } + + #[test] + fn test_list_users() { + let db = test_db(); + db.create_user("charlie", "pass", false).unwrap(); + db.create_user("alice", "pass", true).unwrap(); + + let users = db.list_users().unwrap(); + assert_eq!(users.len(), 2); + assert_eq!(users[0].username, "alice"); // sorted + assert_eq!(users[1].username, "charlie"); + } + + #[test] + fn test_update_user_password() { + let db = test_db(); + let user = db.create_user("dave", "oldpass", false).unwrap(); + + db.update_user(&user.id, Some("newpass"), None).unwrap(); + + assert!(db.verify_password("dave", "oldpass").unwrap().is_none()); + assert!(db.verify_password("dave", "newpass").unwrap().is_some()); + } + + #[test] + fn test_update_user_admin_status() { + let db = test_db(); + let user = db.create_user("eve", "pass", false).unwrap(); + assert!(!user.is_admin); + + db.update_user(&user.id, None, Some(true)).unwrap(); + let updated = db.get_user(&user.id).unwrap().unwrap(); + assert!(updated.is_admin); + } + + #[test] + fn test_delete_user() { + let db = test_db(); + let user = db.create_user("frank", "pass", false).unwrap(); + assert!(db.delete_user(&user.id).unwrap()); + assert!(db.get_user(&user.id).unwrap().is_none()); + } + + #[test] + fn test_delete_nonexistent_user() { + let db = test_db(); + assert!(!db.delete_user("nonexistent-id").unwrap()); + } + + #[test] + fn test_create_and_validate_session() { + let db = test_db(); + let user = db.create_user("grace", "pass", false).unwrap(); + + let token = db.create_session(&user.id, 3600).unwrap(); + let validated = db.validate_session(&token).unwrap(); + assert!(validated.is_some()); + assert_eq!(validated.unwrap().username, "grace"); + } + + #[test] + fn test_expired_session_rejected() { + let db = test_db(); + let user = db.create_user("heidi", "pass", false).unwrap(); + + // Create session that expired 10 seconds ago + let token = db.create_session(&user.id, -10).unwrap(); + let validated = db.validate_session(&token).unwrap(); + assert!(validated.is_none()); + } + + #[test] + fn test_invalid_token_rejected() { + let db = test_db(); + let validated = db.validate_session("bogus-token").unwrap(); + assert!(validated.is_none()); + } + + #[test] + fn test_delete_session() { + let db = test_db(); + let user = db.create_user("ivan", "pass", false).unwrap(); + let token = db.create_session(&user.id, 3600).unwrap(); + + assert!(db.delete_session(&token).unwrap()); + assert!(db.validate_session(&token).unwrap().is_none()); + } + + #[test] + fn test_delete_user_cascades_sessions() { + let db = test_db(); + let user = db.create_user("judy", "pass", false).unwrap(); + let token = db.create_session(&user.id, 3600).unwrap(); + + db.delete_user(&user.id).unwrap(); + assert!(db.validate_session(&token).unwrap().is_none()); + } + + #[test] + fn test_delete_user_cascades_permissions() { + let db = test_db(); + let user = db.create_user("karl", "pass", false).unwrap(); + db.set_permissions( + &user.id, + &[AppPermission { + server: "srv".to_string(), + app_id: 1, + }], + ) + .unwrap(); + + db.delete_user(&user.id).unwrap(); + // Permissions table should be empty for this user + let perms = db.get_permissions(&user.id).unwrap(); + assert!(perms.is_empty()); + } + + #[test] + fn test_set_and_get_permissions() { + let db = test_db(); + let user = db.create_user("laura", "pass", false).unwrap(); + + let perms = vec![ + AppPermission { + server: "server1".to_string(), + app_id: 10, + }, + AppPermission { + server: "server1".to_string(), + app_id: 20, + }, + ]; + db.set_permissions(&user.id, &perms).unwrap(); + + let fetched = db.get_permissions(&user.id).unwrap(); + assert_eq!(fetched.len(), 2); + } + + #[test] + fn test_set_permissions_replaces_existing() { + let db = test_db(); + let user = db.create_user("mike", "pass", false).unwrap(); + + db.set_permissions( + &user.id, + &[AppPermission { + server: "s1".to_string(), + app_id: 1, + }], + ) + .unwrap(); + + db.set_permissions( + &user.id, + &[AppPermission { + server: "s2".to_string(), + app_id: 2, + }], + ) + .unwrap(); + + let perms = db.get_permissions(&user.id).unwrap(); + assert_eq!(perms.len(), 1); + assert_eq!(perms[0].server, "s2"); + assert_eq!(perms[0].app_id, 2); + } + + #[test] + fn test_check_app_permission_allowed() { + let db = test_db(); + let user = db.create_user("nancy", "pass", false).unwrap(); + db.set_permissions( + &user.id, + &[AppPermission { + server: "srv".to_string(), + app_id: 42, + }], + ) + .unwrap(); + + assert!(db.check_app_permission(&user.id, "srv", 42).unwrap()); + } + + #[test] + fn test_check_app_permission_denied() { + let db = test_db(); + let user = db.create_user("oscar", "pass", false).unwrap(); + + assert!(!db.check_app_permission(&user.id, "srv", 42).unwrap()); + } + + #[test] + fn test_check_app_permission_admin_bypass() { + let db = test_db(); + let user = db.create_user("pat", "pass", true).unwrap(); + // Admin has no explicit permissions but should pass + assert!(db.check_app_permission(&user.id, "srv", 42).unwrap()); + } + + #[test] + fn test_cleanup_expired_sessions() { + let db = test_db(); + let user = db.create_user("quinn", "pass", false).unwrap(); + + let _expired = db.create_session(&user.id, -10).unwrap(); + let valid = db.create_session(&user.id, 3600).unwrap(); + + let cleaned = db.cleanup_expired_sessions().unwrap(); + assert_eq!(cleaned, 1); + + // Valid session should still work + assert!(db.validate_session(&valid).unwrap().is_some()); + } + + #[test] + fn test_seed_admin_if_needed() { + let db = test_db(); + + // First call should create admin + let result = db.seed_admin_if_needed().unwrap(); + assert!(result.is_some()); + let (username, password) = result.unwrap(); + assert_eq!(username, "admin"); + assert!(!password.is_empty()); + + // Verify can login with generated password + let user = db.verify_password("admin", &password).unwrap().unwrap(); + assert!(user.is_admin); + + // Second call should be a no-op + let result = db.seed_admin_if_needed().unwrap(); + assert!(result.is_none()); + } +} diff --git a/gamestream-webtransport-proxy/src/main.rs b/gamestream-webtransport-proxy/src/main.rs index 98552f9..d3e1161 100644 --- a/gamestream-webtransport-proxy/src/main.rs +++ b/gamestream-webtransport-proxy/src/main.rs @@ -3,9 +3,11 @@ use salvo::logging::Logger; use salvo::prelude::*; mod apps; +mod auth; mod backend; mod certs; mod common; +mod db; mod gamestream; mod pair; mod proxy; @@ -40,12 +42,71 @@ fn create_static_handler() -> impl Handler { async fn run_backend(port: u16) -> Result<()> { let backend = backend::Backend::new(port)?; + + // Seed default admin user if no users exist + if let Some((username, password)) = backend.db.seed_admin_if_needed()? { + tracing::info!("Created default admin user: {username}"); + println!("==========================================="); + println!(" Default admin credentials:"); + println!(" Username: {username}"); + println!(" Password: {password}"); + println!("==========================================="); + } + + // Clean up expired sessions on startup + if let Ok(cleaned) = backend.db.cleanup_expired_sessions() { + if cleaned > 0 { + tracing::info!("Cleaned up {cleaned} expired sessions"); + } + } + let backend_arc = std::sync::Arc::new(backend); + let auth_middleware = auth::SessionAuthMiddleware { + db: std::sync::Arc::new( + db::Db::open( + &directories::ProjectDirs::from("xyz", "ohea", "gamestream-webtransport-proxy") + .ok_or(anyhow!("Could not get project dirs"))? + .data_dir() + .join("auth.db"), + )?, + ), + }; + let router = Router::new() + // Public auth routes + .push(Router::with_path("api/auth/login").post(backend_arc.login())) + // Existing routes (not yet gated - will be gated in a subsequent commit) .push(Router::with_path("api/pair").post(backend_arc.post_pair())) .push(Router::with_path("api/apps").get(backend_arc.get_apps())) .push(Router::with_path("api/stream/start").post(backend_arc.post_stream_start())) + // Authenticated routes + .push( + Router::with_path("api") + .hoop(auth_middleware) + .push(Router::with_path("auth/logout").post(backend_arc.logout())) + .push(Router::with_path("auth/me").get(backend_arc.me())) + // Admin-only routes + .push( + Router::with_path("admin") + .hoop(auth::AdminCheckMiddleware) + .push( + Router::with_path("users") + .get(backend_arc.admin_list_users()) + .post(backend_arc.admin_create_user()), + ) + .push( + Router::with_path("users/") + .put(backend_arc.admin_update_user()) + .delete(backend_arc.admin_delete_user()), + ) + .push( + Router::with_path("users//permissions") + .get(backend_arc.admin_get_permissions()) + .put(backend_arc.admin_set_permissions()), + ), + ), + ) .push(Router::with_path("{*path}").get(create_static_handler())); let doc = OpenApi::new("test api", "0.0.1").merge_router(&router); let router = router @@ -66,8 +127,6 @@ async fn run_backend(port: u16) -> Result<()> { async fn run_proxy(port: u16, stream_id: uuid::Uuid) -> Result<()> { let (config, cert_hash) = certs::get_webtransport_stream_config(stream_id)?; - //let config = certs::get_http_stream_config()?; - //let cert_hash = [0; 32]; let proxy = proxy::Proxy::new(cert_hash); let proxy_arc = std::sync::Arc::new(proxy);