diff --git a/examples/generate_ts_sdk.rs b/examples/generate_ts_sdk.rs index fb58c038..ea762f53 100644 --- a/examples/generate_ts_sdk.rs +++ b/examples/generate_ts_sdk.rs @@ -1,3 +1,4 @@ +use std::fmt::Write as _; use std::io::Write as _; use ts_bindgen::{TypeRegistry, TypeScriptDef, TypeScriptType}; @@ -8,65 +9,92 @@ fn main() -> Result<(), Box> { client_sdk::models::gateway::message::ServerMsg::register(&mut registry); client_sdk::models::gateway::message::ClientMsg::register(&mut registry); + client_sdk::api::error::ApiError::register(&mut registry); + client_sdk::api::commands::register_routes(&mut registry); - let mut models = std::fs::File::create("autogenerated.ts")?; + // generate TypeScript bindings, all of them + + let mut autogenerated = std::fs::File::create("out/autogenerated.ts")?; - write!(models, "import type {{ ")?; + write!(autogenerated, "import type {{ ")?; for (idx, name) in registry.external().iter().enumerate() { if idx > 0 { - write!(models, ", ")?; + write!(autogenerated, ", ")?; } - write!(models, "{name}")?; + write!(autogenerated, "{name}")?; } write!( - models, - " }} from './models';\nimport {{ command }} from './api';\n\n{}", + autogenerated, + " }} from './models';\nimport {{ command }} from './api/command';\n\n{}", registry.display() )?; - let mut api = std::fs::File::create("api.ts")?; + // - for group in ["decl", "values", "types"] { - let mut first = true; - let mut len = 0; + let models = std::fs::File::create("out/models.ts")?; + let api = std::fs::File::create("out/api.ts")?; + let gateway = std::fs::File::create("out/gateway.ts")?; + let mut out = String::new(); - if group == "types" { - writeln!(api, "export type {{")?; - } else { - writeln!(api, "export {{")?; - } + for (mut file, tag) in [(models, ""), (api, "command"), (gateway, "gateway")] { + for group in ["decl", "values", "types"] { + let tys = registry.iter().filter(|(name, _)| { + if tag.is_empty() { + registry.type_tags(name).count() == 0 + } else { + registry.has_tag(name, tag) + } + }); - for (name, ty) in registry.iter() { - match group { - "decl" if matches!(ty, TypeScriptType::ApiDecl { .. }) => {} - "values" if ty.is_value() && !matches!(ty, TypeScriptType::ApiDecl { .. }) => {} - "types" if !ty.is_value() => {} - _ => continue, - } + let mut idx = 0; + + for (name, ty) in tys { + match group { + "decl" if matches!(ty, TypeScriptType::ApiDecl { .. }) => {} + "values" if ty.is_value() && !matches!(ty, TypeScriptType::ApiDecl { .. }) => {} + "types" if !ty.is_value() => {} + _ => continue, + } - if !first { - if len % 5 == 0 { - write!(api, ",\n ")?; + if idx == 0 { + write!(out, " {name}")?; + } else if idx % 5 == 0 { + write!(out, ",\n {name}")?; } else { - write!(api, ", ")?; + write!(out, ", {name}")?; } + + idx += 1; + } + + if out.is_empty() { + continue; + } + + let comment = match group { + "decl" => "API Command declarations", + "values" => "Exported const values", + "types" => "Exported types", + _ => unreachable!(), + }; + + if group == "types" { + writeln!(file, "/** {comment} */\nexport type {{")?; } else { - write!(api, " ")?; + writeln!(file, "/** {comment} */\nexport {{")?; } - first = false; + file.write_all(out.as_bytes())?; - write!(api, "{}", name)?; + out.clear(); - len += 1; + write!(file, "\n}} from '../autogenerated';\n\n")?; } - - write!(api, "\n}} from './autogenerated';\n\n")?; } Ok(()) diff --git a/src/api/command.rs b/src/api/command.rs index 94fa0c93..46b900e5 100644 --- a/src/api/command.rs +++ b/src/api/command.rs @@ -12,16 +12,18 @@ use crate::models::Permissions; bitflags2! { /// Flags for command functionality. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] - pub struct CommandFlags: u8 { + pub struct CommandFlags: u8 where "command" { /// Command requires authorization to execute. const AUTHORIZED = 1 << 0; /// Command has a request body. const HAS_BODY = 1 << 1; + const HAS_RESPONSE = 1 << 2; + const STREAMING = 1 << 3; - const BOTS_ONLY = 1 << 2; - const USERS_ONLY = 1 << 3; - const ADMIN_ONLY = 1 << 4; + const BOTS_ONLY = 1 << 5; + const USERS_ONLY = 1 << 6; + const ADMIN_ONLY = 1 << 7; } } @@ -120,9 +122,6 @@ impl core::error::Error for MissingItemError {} /// /// For the case of `GET`/`OPTIONS` commands, the body becomes query parameters. pub trait Command: sealed::Sealed { - /// Whether the command returns one or many items - const STREAM: bool; - /// Whether the command has a query string or sends a body const IS_QUERY: bool; @@ -275,9 +274,9 @@ macro_rules! command { (@GET TRACE $c:block) => {$c}; (@GET $other:ident $c:block) => {}; - (@IS_STREAM One) => { false }; - (@IS_STREAM Many) => { true }; - (@IS_STREAM $other:ident) => { compile_error!("Must use One or Many for Command result") }; + (@STREAMING One) => { CommandFlags::empty() }; + (@STREAMING Many) => { CommandFlags::STREAMING }; + (@STREAMING $other:ident) => { compile_error!("Must use One or Many for Command result") }; (@AGGREGATE One $ty:ty) => { $ty }; (@AGGREGATE Many $ty:ty) => { Vec<$ty> }; @@ -378,8 +377,6 @@ macro_rules! command { impl $crate::api::command::sealed::Sealed for $name {} impl $crate::api::command::Command for $name { - const STREAM: bool = command!(@IS_STREAM $count); - const IS_QUERY: bool = matches!( http::Method::$method, http::Method::GET | http::Method::OPTIONS | http::Method::HEAD | http::Method::CONNECT | http::Method::TRACE @@ -393,7 +390,14 @@ macro_rules! command { const HTTP_METHOD: http::Method = http::Method::$method; - const FLAGS: CommandFlags = CommandFlags::empty() + const FLAGS: CommandFlags = CommandFlags::empty().union(command!(@STREAMING $count)) + .union(const { + if size_of::<$result>() != 0 { + CommandFlags::HAS_RESPONSE + } else { + CommandFlags::empty() + } + }) $(.union((stringify!($body_name), CommandFlags::HAS_BODY).1))? $(.union((stringify!($auth_struct), CommandFlags::AUTHORIZED).1))? $( $(.union(CommandFlags::$flag))* )? @@ -634,6 +638,8 @@ macro_rules! command { registry.insert(stringify!($name), ty, concat!($(command!(@DOC #[$($meta)*])),*).trim()); + registry.tag(stringify!($name), "command"); + TypeScriptType::Named(stringify!($name)) } } @@ -649,7 +655,7 @@ macro_rules! command { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] $(#[$body_meta])* pub struct $body_name { $( $(#[$($body_field_meta)*])* $body_field_vis $body_field_name: $body_field_ty ),* @@ -750,9 +756,7 @@ macro_rules! command_module { #[cfg(feature = "ts")] pub fn register_routes(registry: &mut ts_bindgen::TypeRegistry) { - $( - paste::paste! { $mod::[](registry); } - )* + paste::paste! { $( $mod::[](registry); )* } } // TODO: Collect schemas from each object diff --git a/src/api/commands/file.rs b/src/api/commands/file.rs index f006811a..45330d40 100644 --- a/src/api/commands/file.rs +++ b/src/api/commands/file.rs @@ -43,7 +43,7 @@ command! { File; #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] -#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] +#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] pub struct FilesystemStatus { pub quota_used: i64, pub quota_total: i64, @@ -52,7 +52,7 @@ pub struct FilesystemStatus { #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] -#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] +#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] pub struct FileStatus { pub complete: u32, pub upload_offset: u64, diff --git a/src/api/commands/party.rs b/src/api/commands/party.rs index 4c58ab22..4f7136af 100644 --- a/src/api/commands/party.rs +++ b/src/api/commands/party.rs @@ -268,6 +268,7 @@ command! { Party; decl_enum! { #[derive(Default, serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] pub enum CreateRoomKind: u8 { #[default] 0 = Text, @@ -280,7 +281,7 @@ decl_enum! { #[cfg_attr(feature = "typed-builder", derive(typed_builder::TypedBuilder))] #[cfg_attr(feature = "bon", derive(bon::Builder))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] -#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] +#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] pub struct PartySettings { pub flags: PartyFlags, pub prefs: PartyPreferences, diff --git a/src/api/commands/user.rs b/src/api/commands/user.rs index cfc8046b..7f5cd4d7 100644 --- a/src/api/commands/user.rs +++ b/src/api/commands/user.rs @@ -226,6 +226,7 @@ impl From for UpdateUserPrefsBody { decl_enum! { #[derive(Default, serde_repr::Deserialize_repr, serde_repr::Serialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] pub enum BannerAlign: u8 { #[default] 0 = Top, @@ -236,9 +237,9 @@ decl_enum! { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] -#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] +#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command"))] pub struct Added2FA { - /// URL to be display as a QR code and added to an authenticator app + /// URL to be displayed as a QR code and added to an authenticator app pub url: String, /// Backup codes to be stored in a safe place pub backup: Vec, diff --git a/src/api/error.rs b/src/api/error.rs index 5e59a8ba..537d3caf 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -8,6 +8,11 @@ use http::StatusCode; #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] +#[cfg_attr( + feature = "ts", + derive(ts_bindgen::TypeScriptDef), + ts(tag = "command", rename = "RawApiError") // we use a separate ApiError class in TypeScript +)] pub struct ApiError { /// Error code pub code: ApiErrorCode, @@ -73,6 +78,7 @@ error_codes! { /// Standard API error codes. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema_repr))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "command", non_const))] #[derive(enum_primitive_derive::Primitive)] pub enum ApiErrorCode: u16 = Unknown { // Server errors diff --git a/src/models/auth.rs b/src/models/auth.rs index dc4e30a1..e151a007 100644 --- a/src/models/auth.rs +++ b/src/models/auth.rs @@ -27,7 +27,8 @@ const MAX_LENGTH: usize = { #[derive(Debug, Clone, Copy, Serialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] -#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] +// rename the exported TypeScript to "RawAuthToken" since it's just a string +#[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(rename = "RawAuthToken"))] #[serde(untagged)] pub enum AuthToken { /// Bearer token for users, has a fixed length of 28 bytes. diff --git a/src/models/gateway.rs b/src/models/gateway.rs index faec6389..942121c6 100644 --- a/src/models/gateway.rs +++ b/src/models/gateway.rs @@ -4,7 +4,7 @@ use super::*; bitflags2! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] - pub struct Intent: u32 { + pub struct Intent: u32 where "gateway" { /// - PARTY_CREATE /// - PARTY_UPDATE /// - PARTY_DELETE @@ -93,7 +93,7 @@ pub mod commands { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct Identify { pub auth: AuthToken, pub intent: Intent, @@ -102,7 +102,7 @@ pub mod commands { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct SetPresence { #[serde(flatten)] pub presence: UserPresence, @@ -115,7 +115,7 @@ pub mod events { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct Hello { /// Number of milliseconds between heartbeats pub heartbeat_interval: u32, @@ -132,7 +132,7 @@ pub mod events { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct ReadyParty { pub party: Party, @@ -166,7 +166,7 @@ pub mod events { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct Ready { pub user: User, @@ -183,7 +183,7 @@ pub mod events { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct TypingStart { pub room_id: RoomId, pub party_id: PartyId, @@ -198,7 +198,7 @@ pub mod events { #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct PartyPositionUpdate { pub id: PartyId, pub position: i16, @@ -207,7 +207,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct UserPresenceEvent { #[serde(default, skip_serializing_if = "Option::is_none")] pub party_id: Option, @@ -218,7 +218,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct MessageDeleteEvent { pub id: MessageId, pub room_id: RoomId, @@ -228,7 +228,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct RoleDeleteEvent { pub id: RoleId, pub party_id: PartyId, @@ -237,7 +237,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct RoomDeleteEvent { pub id: RoomId, @@ -249,7 +249,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct PartyMemberEvent { pub party_id: PartyId, @@ -260,7 +260,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] #[serde(untagged)] pub enum PartyUpdateEvent { Position(PartyPositionUpdate), @@ -270,7 +270,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct UserReactionEvent { pub user_id: UserId, pub room_id: RoomId, @@ -285,7 +285,7 @@ pub mod events { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] pub struct ProfileUpdateEvent { #[serde(default, skip_serializing_if = "Option::is_none")] pub party_id: Option, @@ -330,7 +330,7 @@ pub mod message { #[doc = "OpCodes for [`" $name "`]"] #[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema_repr))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] #[repr(u8)] pub enum [<$name Opcode>] { $($opcode = $code,)* @@ -346,7 +346,7 @@ pub mod message { #[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef), ts(tag = "gateway"))] $(#[derive($Default, PartialEq, Eq)])? pub struct [<$opcode Payload>] { $($(#[$field_meta])* pub $field : $ty,)* @@ -562,6 +562,8 @@ pub mod message { registry.insert(stringify!($name), TypeScriptType::Union(variants), concat!("Union of all ", stringify!($name), " messages")); + registry.tag(stringify!($name), "gateway"); + TypeScriptType::Named(stringify!($name)) } } diff --git a/src/models/message.rs b/src/models/message.rs index 7a383049..ccf6ee5e 100644 --- a/src/models/message.rs +++ b/src/models/message.rs @@ -42,6 +42,7 @@ decl_enum! { #[derive(Default)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] #[derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] #[derive(enum_primitive_derive::Primitive)] pub enum MessageKind: i16 { #[default] diff --git a/src/models/user/mod.rs b/src/models/user/mod.rs index f28f27c8..aec29d94 100644 --- a/src/models/user/mod.rs +++ b/src/models/user/mod.rs @@ -261,6 +261,7 @@ mod tests { enum_codes! { #[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash)] #[derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] #[derive(enum_primitive_derive::Primitive)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] pub enum UserRelationship: i8 = None { diff --git a/src/models/user/prefs.rs b/src/models/user/prefs.rs index fab0c430..2484e261 100644 --- a/src/models/user/prefs.rs +++ b/src/models/user/prefs.rs @@ -2,6 +2,7 @@ use super::*; enum_codes! { #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] #[allow(non_camel_case_types)] pub enum Locale: u16 = enUS { #[default] @@ -11,6 +12,7 @@ enum_codes! { enum_codes! { #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] #[cfg_attr(feature = "ts", ts(non_const))] pub enum Font: u16 = SansSerif { #[default] @@ -28,6 +30,7 @@ enum_codes! { enum_codes! { #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] + #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] pub enum FriendAddability: u8 = None { #[default] 0 = None, diff --git a/src/models/util/macros.rs b/src/models/util/macros.rs index 6d8d7206..93f52192 100644 --- a/src/models/util/macros.rs +++ b/src/models/util/macros.rs @@ -4,7 +4,7 @@ macro_rules! bitflags2 { ( $(#[$outer:meta])* - $vis:vis struct $BitFlags:ident: $T:ty { + $vis:vis struct $BitFlags:ident: $T:ty $(where $tag:literal)? { $( $(#[$inner:ident $($args:tt)*])* const $Flag:tt = $value:expr; @@ -141,6 +141,8 @@ macro_rules! bitflags2 { concat!("Bitflags for ", stringify!($BitFlags)), ); + $( registry.tag(name, $tag); )? + ty } } @@ -168,7 +170,6 @@ macro_rules! enum_codes { } ) => { rkyv_rpc::enum_codes! { - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] $(#[$meta])* $vis enum $name: $archived_vis $repr $(= $unknown)? { $($(#[$variant_meta])* $code = $variant,)* @@ -185,7 +186,6 @@ macro_rules! enum_codes { $($(#[$variant_meta:meta])* $code:literal = $variant:ident,)* } ) => { - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] $(#[$meta])* #[repr($repr)] $vis enum $name { @@ -206,7 +206,6 @@ macro_rules! decl_enum { ) => { rkyv_rpc::unit_enum! { $(#[$meta])* - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] $vis enum $name: $repr { $($(#[$variant_meta])* $code = $variant,)* } @@ -225,7 +224,6 @@ macro_rules! decl_enum { $(#[$meta])* #[repr($repr)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] - #[cfg_attr(feature = "ts", derive(ts_bindgen::TypeScriptDef))] $vis enum $name { $($(#[$variant_meta])* $variant = $code,)* } diff --git a/ts-bindgen/src/impls/alloc_impl.rs b/ts-bindgen/src/impls/alloc_impl.rs index c4b37d6d..80e70393 100644 --- a/ts-bindgen/src/impls/alloc_impl.rs +++ b/ts-bindgen/src/impls/alloc_impl.rs @@ -1,5 +1,6 @@ extern crate alloc; +use alloc::borrow::Cow; use alloc::boxed::Box; use alloc::rc::Rc; use alloc::string::String; @@ -43,3 +44,13 @@ impl TypeScriptDef for VecDeque { T::register(registry).into_array() } } + +impl<'a, T: TypeScriptDef> TypeScriptDef for Cow<'a, T> +where + T: ?Sized + 'a, + T: ToOwned, +{ + fn register(registry: &mut TypeRegistry) -> TypeScriptType { + T::register(registry) + } +} diff --git a/ts-bindgen/src/registry.rs b/ts-bindgen/src/registry.rs index 145812dd..40d6ce56 100644 --- a/ts-bindgen/src/registry.rs +++ b/ts-bindgen/src/registry.rs @@ -1,6 +1,6 @@ use indexmap::{IndexMap, IndexSet}; -use std::borrow::Cow; +use std::{borrow::Cow, collections::HashMap}; use crate::TypeScriptType; @@ -8,6 +8,7 @@ use crate::TypeScriptType; pub struct TypeRegistry { // use IndexMap to preserve the insertion order types: IndexMap<&'static str, (TypeScriptType, Cow<'static, str>)>, + tags: HashMap<&'static str, Vec<&'static str>>, external: IndexSet>, } @@ -36,6 +37,30 @@ impl TypeRegistry { self.types.insert(name, (ty, comment.into())); } + /// Adds an arbitrary string tag to a named type + pub fn tag(&mut self, name: &'static str, tag: &'static str) { + self.tags.entry(tag).or_default().push(name); + } + + /// Returns all tags for the given type + pub fn get_tags(&self, tag: &'static str) -> Option<&[&'static str]> { + self.tags.get(tag).map(|v| &v[..]) + } + + /// Returns true if the type has the given tag + pub fn has_tag(&self, name: &'static str, tag: &'static str) -> bool { + self.tags.get(tag).map_or(false, |v| v.contains(&name)) + } + + /// Returns all types with the given tag + pub fn tagged_types(&self, tag: &'static str) -> impl Iterator { + self.tags.get(tag).into_iter().flat_map(move |v| v.iter().map(move |&name| (name, &self.types[name].0))) + } + + pub fn type_tags(&self, ty: &'static str) -> impl Iterator + '_ { + self.tags.iter().filter_map(move |(tag, names)| if names.contains(&ty) { Some(*tag) } else { None }) + } + pub fn get(&self, name: &'static str) -> Option<&TypeScriptType> { self.types.get(name).map(|(ty, _)| ty) } @@ -223,8 +248,6 @@ impl TypeRegistry { body_type, path, } => { - let parse_response = **return_type != TypeScriptType::Null; - let body_type = match body_type { Some(ty) => ty, None => &TypeScriptType::Null, @@ -236,10 +259,6 @@ impl TypeRegistry { method.to_lowercase() )?; - if parse_response { - writeln!(out, " parse: true,")?; - } - if path.contains("${") { writeln!(out, " path() {{ return `{path}`; }},")?; } else { diff --git a/ts-bindgen/ts-bindgen-macros/src/lib.rs b/ts-bindgen/ts-bindgen-macros/src/lib.rs index 3c710e6c..a005dab8 100644 --- a/ts-bindgen/ts-bindgen-macros/src/lib.rs +++ b/ts-bindgen/ts-bindgen-macros/src/lib.rs @@ -21,6 +21,8 @@ pub fn derive_typescript_def(input: proc_macro::TokenStream) -> proc_macro::Toke non_const: false, includes: Vec::new(), max: false, + tags: Vec::new(), + rename: None, comment: extract_doc_comments(&input.attrs), }; @@ -32,12 +34,18 @@ pub fn derive_typescript_def(input: proc_macro::TokenStream) -> proc_macro::Toke return e.into_compile_error().into(); } - let name = input.ident; + let name = Ident::new( + match attrs.rename { + Some(ref name) => name, + None => attrs.serde.name().serialize_name(), + }, + input.ident.span(), + ); - let includes = &attrs.includes; - let includes = quote! { - #( #includes::register(registry); )* - }; + let rust_name = input.ident; + + let includes = std::mem::take(&mut attrs.includes); + let tags = std::mem::take(&mut attrs.tags); let inner = match input.data { Data::Enum(data) => derive_enum(data, name.clone(), attrs), @@ -46,15 +54,19 @@ pub fn derive_typescript_def(input: proc_macro::TokenStream) -> proc_macro::Toke }; proc_macro::TokenStream::from(quote! { - impl ts_bindgen::TypeScriptDef for #name { + impl ts_bindgen::TypeScriptDef for #rust_name { fn register(registry: &mut ts_bindgen::TypeRegistry) -> ts_bindgen::TypeScriptType { if registry.contains(stringify!(#name)) { return ts_bindgen::TypeScriptType::Named(stringify!(#name)); } - #includes + #( #includes::register(registry); )* + + #( registry.tag(stringify!(#name), #tags); )* #inner + + ts_bindgen::TypeScriptType::Named(stringify!(#name)) } } }) @@ -77,6 +89,12 @@ struct ItemAttributes { /// Include other types in the generated register function. includes: Vec, + + /// Tags to give the type in the registry + tags: Vec, + + /// Only rename the TypeScript type, not the serde type. + rename: Option, } impl ItemAttributes { @@ -107,6 +125,18 @@ impl ItemAttributes { })?; } + if meta.path.is_ident("tag") { + let tag: syn::LitStr = meta.value()?.parse()?; + + self.tags.push(tag.value()); + } + + if meta.path.is_ident("rename") { + let rename: syn::LitStr = meta.value()?.parse()?; + + self.rename = Some(rename.value()); + } + Ok(()) })?; } @@ -123,12 +153,9 @@ fn derive_struct(input: syn::DataStruct, name: Ident, attrs: ItemAttributes) -> // unit types are just null if let Fields::Unit = input.fields { out.extend(if attrs.inline { - quote! { ts_bindgen::TypeScriptType::Null } + quote! { return ts_bindgen::TypeScriptType::Null; } } else { - quote! { - registry.insert(stringify!(#name), ts_bindgen::TypeScriptType::Null, #struct_comment); - ts_bindgen::TypeScriptType::Named(stringify!(#name)) - } + quote! { registry.insert(stringify!(#name), ts_bindgen::TypeScriptType::Null, #struct_comment); } }); return out; @@ -192,16 +219,13 @@ fn derive_struct(input: syn::DataStruct, name: Ident, attrs: ItemAttributes) -> cmt }); - - ts_bindgen::TypeScriptType::Named(stringify!(#name)) } } } else if attrs.inline { - quote! { ts_bindgen::TypeScriptType::Tuple(fields) } + quote! { return ts_bindgen::TypeScriptType::Tuple(fields); } } else { quote! { registry.insert(stringify!(#name), ts_bindgen::TypeScriptType::Tuple(fields), #struct_comment); - ts_bindgen::TypeScriptType::Named(stringify!(#name)) } }); } else { @@ -248,14 +272,12 @@ fn derive_struct(input: syn::DataStruct, name: Ident, attrs: ItemAttributes) -> let num_extends = flattened.len(); out.extend(if attrs.inline { - quote! { ts_bindgen::TypeScriptType::interface(members, #num_extends) #(.flatten(#flattened))*; } + quote! { return ts_bindgen::TypeScriptType::interface(members, #num_extends) #(.flatten(#flattened))*; } } else { quote! { let ty = ts_bindgen::TypeScriptType::interface(members, #num_extends) #(.flatten(#flattened))*; registry.insert(stringify!(#name), ty, #struct_comment); - - ts_bindgen::TypeScriptType::Named(stringify!(#name)) } }); } @@ -358,11 +380,7 @@ fn derive_enum(input: syn::DataEnum, name: Ident, attrs: ItemAttributes) -> Toke out.extend(quote! { let ty = ts_bindgen::TypeScriptType::#ty(variants); }); } - out.extend(quote! { - registry.insert(stringify!(#name), ty, #enum_comment); - - ts_bindgen::TypeScriptType::Named(stringify!(#name)) - }); + out.extend(quote! { registry.insert(stringify!(#name), ty, #enum_comment); }); return out; } @@ -607,10 +625,7 @@ fn derive_enum(input: syn::DataEnum, name: Ident, attrs: ItemAttributes) -> Toke } out.extend(quote! { - let ty = ts_bindgen::TypeScriptType::Union(variants); - registry.insert(stringify!(#name), ty, #enum_comment); - - ts_bindgen::TypeScriptType::Named(stringify!(#name)) + registry.insert(stringify!(#name), ts_bindgen::TypeScriptType::Union(variants), #enum_comment); }); out