diff --git a/src/overrides.rs b/src/overrides.rs index 4c888f4..309f2f9 100644 --- a/src/overrides.rs +++ b/src/overrides.rs @@ -2,12 +2,14 @@ use syn::{Attribute, Meta, NestedMeta, Lit}; pub struct Overrides { pub name: Option, + pub transparent: bool, } impl Overrides { pub fn extract(attrs: &[Attribute]) -> Result { let mut overrides = Overrides { name: None, + transparent: false, }; for attr in attrs { @@ -39,11 +41,22 @@ impl Overrides { overrides.name = Some(value); }, - _ => return Err("expected a name-value meta item".to_owned()), + NestedMeta::Meta(Meta::Word(ref meta)) => { + if meta.as_ref() == "transparent" { + overrides.transparent = true; + } else { + return Err(format!("unknown override `{}`", meta.as_ref())); + } + } + _ => return Err("expected a name-value or word meta item".to_owned()), } } } + if overrides.name.is_some() && overrides.transparent { + return Err("overrides `name` and `transparent` may not be used at the same time".to_owned()) + } + Ok(overrides) } } diff --git a/src/tosql.rs b/src/tosql.rs index 4ea89ab..c1abf4c 100644 --- a/src/tosql.rs +++ b/src/tosql.rs @@ -11,6 +11,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let overrides = Overrides::extract(&input.attrs)?; let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let transparent = overrides.transparent; let (accepts_body, to_sql_body) = match input.data { Data::Enum(ref data) => { @@ -18,8 +19,13 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { (accepts::enum_body(&name, &variants), enum_body(&input.ident, &variants)) } Data::Struct(DataStruct { fields: Fields::Unnamed(ref fields), .. }) if fields.unnamed.len() == 1 => { - let field = fields.unnamed.first().unwrap().into_value(); - (domain_accepts_body(&name, &field), domain_body()) + if transparent { + let field = fields.unnamed.first().unwrap().into_value(); + (transparent_accepts_body(&field), transparent_body()) + } else { + let field = fields.unnamed.first().unwrap().into_value(); + (domain_accepts_body(&name, &field), domain_body()) + } } Data::Struct(DataStruct { fields: Fields::Named(ref fields), .. }) => { let fields = fields.named.iter().map(Field::parse).collect::, _>>()?; @@ -73,6 +79,20 @@ fn enum_body(ident: &Ident, variants: &[Variant]) -> Tokens { } } +fn transparent_accepts_body(field: &syn::Field) -> Tokens { + let ty = &field.ty; + + quote! { + <#ty as ::postgres::types::ToSql>::accepts(type_) + } +} + +fn transparent_body() -> Tokens { + quote! { + ::postgres::types::ToSql::to_sql(&self.0, _type, buf) + } +} + fn domain_accepts_body(name: &str, field: &syn::Field) -> Tokens { let ty = &field.ty; diff --git a/tests/wrapper.rs b/tests/wrapper.rs new file mode 100644 index 0000000..f482e50 --- /dev/null +++ b/tests/wrapper.rs @@ -0,0 +1,49 @@ +#[macro_use] +extern crate postgres_derive; +#[macro_use] +extern crate postgres; + +use postgres::{Connection, TlsMode}; +use postgres::types::WrongType; + +mod util; + +#[test] +fn transparent() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(transparent)] + struct ResourceId(i32); + + let conn = Connection::connect("postgres://postgres:password@localhost", TlsMode::None) + .unwrap(); + + util::test_type( + &conn, + "\"int4\"", + &[ + ( + ResourceId(123), + "123", + ), + ( + ResourceId(-27), + "-27", + ), + ], + ); +} + +#[test] +fn wrong_type() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(transparent)] + struct ResourceId(i32); + + let conn = Connection::connect("postgres://postgres:password@localhost", TlsMode::None) + .unwrap(); + + let err = conn.execute("SELECT $1::date", &[&ResourceId(0)]) + .unwrap_err(); + assert!(err.as_conversion().unwrap().is::()); +} +