Skip to content

Commit

Permalink
Rewrite connection setup, hide some types, Update Docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Einliterflasche committed Aug 29, 2023
1 parent e119bd4 commit 1e6ded3
Show file tree
Hide file tree
Showing 11 changed files with 504 additions and 202 deletions.
2 changes: 1 addition & 1 deletion pg-worm-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0.56"
quote = "1.0.27"
syn = { version = "2.0.15", features = ["full"] }
syn = { version = "2.0.15", features = ["derive"] }
darling = "0.20"
postgres-types = "0.2"
145 changes: 57 additions & 88 deletions pg-worm-derive/src/parse.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use darling::{ast::Data, FromDeriveInput, FromField};
use darling::{ast::Data, Error, FromDeriveInput, FromField};
use postgres_types::Type;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{GenericArgument, Ident, PathArguments};
use syn::{Ident, PathArguments};

#[derive(FromDeriveInput)]
#[darling(attributes(table), supports(struct_named))]
Expand All @@ -18,7 +18,6 @@ pub struct ModelInput {
pub struct ModelField {
ident: Option<syn::Ident>,
ty: syn::Type,
dtype: Option<String>,
column_name: Option<String>,
#[darling(default)]
auto: bool,
Expand Down Expand Up @@ -64,15 +63,15 @@ impl ModelInput {

/// Generate the SQL statement needed to create
/// the table corresponding to the input.
fn table_creation_sql(&self) -> String {
format!(
fn table_creation_sql(&self) -> Result<String, Error> {
Ok(format!(
"CREATE TABLE {} ({})",
self.table_name(),
self.all_fields()
.map(|f| f.column_creation_sql())
.collect::<Vec<String>>()
.collect::<Result<Vec<String>, Error>>()?
.join(", ")
)
))
}

/// Generate all code needed.
Expand Down Expand Up @@ -102,7 +101,10 @@ impl ModelInput {
fn impl_model(&self) -> TokenStream {
let ident = self.ident();
let table_name = self.table_name();
let creation_sql = self.table_creation_sql();
let creation_sql = match self.table_creation_sql() {
Ok(res) => quote!(#res),
Err(err) => err.write_errors(),
};

let select = self.impl_select();
let delete = self.impl_delete();
Expand Down Expand Up @@ -314,6 +316,12 @@ impl ModelInput {
}
}

macro_rules! spanned_error {
($msg:expr, $err:expr) => {
return Err(darling::Error::custom($msg).with_span($err))
};
}

impl ModelField {
/// Initialization function called before each
/// field is stored.
Expand All @@ -322,10 +330,12 @@ impl ModelField {

// Extract relevant type from the path
let syn::Type::Path(path) = ty else {
panic!("field type must be valid path");
spanned_error!("unsupported type", &ty)
};
let path = &path.path;
let last_seg = path.segments.last().expect("must provide type");
let Some(last_seg) = path.segments.last() else {
spanned_error!("invalid path (needs at least one segment)", &ty)
};

match last_seg.ident.to_string().as_str() {
// If it's an Option<T>, set the field nullable
Expand Down Expand Up @@ -354,102 +364,61 @@ impl ModelField {
self.ident().to_string().to_lowercase()
}

/// Get the corresponding column's PostgreSQL datatype.
fn pg_datatype(&self) -> Type {
fn from_str(ty: &str) -> Type {
match ty {
"bool" | "boolean" => Type::BOOL,
"text" => Type::TEXT,
"int" | "integer" | "int4" => Type::INT4,
"bigint" | "int8" => Type::INT8,
"smallint" | "int2" => Type::INT2,
"real" => Type::FLOAT4,
"double precision" => Type::FLOAT8,
"bigserial" => Type::INT8,
_ => panic!("couldn't find postgres type `{}`", ty),
}
}
/// Get the corresponding postgres type
fn try_pg_datatype(&self) -> Result<Type, Error> {
let ty = self.ty.clone();

fn from_type(ty: &Ident) -> Type {
match ty.to_string().as_str() {
"String" => Type::TEXT,
"i32" => Type::INT4,
"i64" => Type::INT8,
"f32" => Type::FLOAT4,
"f64" => Type::FLOAT8,
"bool" => Type::BOOL,
_ => panic!("cannot map rust type to postgres type: {ty}"),
}
}

if let Some(dtype) = &self.dtype {
return from_str(dtype.as_str());
}

let syn::Type::Path(type_path) = &self.ty else {
panic!("field type must be path; no reference, impl, etc. allowed")
let syn::Type::Path(path) = &self.ty else {
spanned_error!("pg-worm: unsupported type, must be a TypePath", &ty)
};

let segment = type_path
.path
.segments
.last()
.expect("field type must have a last segment");
let args = &segment.arguments;

if segment.ident.to_string().as_str() == "Option" {
// Extract `T` from `Option<T>`
let PathArguments::AngleBracketed(args) = args else {
panic!("field of type option needs angle bracketed argument")
};
let GenericArgument::Type(arg) = args.args.first().expect("Option needs to have generic argument") else {
panic!("generic argument for Option must be concrete type")
};
let syn::Type::Path(type_path) = arg else {
panic!("generic arg for Option must be path")
};
let Some(segment) = path.path.segments.last() else {
spanned_error!("pg-worm: unsupported type path, must have at least one segment", &ty)
};

let ident = &type_path
.path
.segments
.first()
.expect("generic arg for Option must have segment")
.ident;
let mut id = &segment.ident;

return from_type(ident);
}
if self.array || self.nullable {
let PathArguments::AngleBracketed(args) = &segment.arguments else {
spanned_error!("pg-worm: unsupported type, Option/Vec need generic argument", &ty)
};

if segment.ident.to_string().as_str() == "Vec" {
// Extract `T` from `Option<T>`
let PathArguments::AngleBracketed(args) = args else {
panic!("field of type Vec needs angle bracketed argument")
let Some(arg) = args.args.first() else {
spanned_error!("pg-worm: unsupported type, Option/Vec need generic argument", &ty)
};
let GenericArgument::Type(arg) = args.args.first().expect("Vec needs to have generic argument") else {
panic!("generic argument for Vec must be concrete type")

let syn::GenericArgument::Type(arg_type) = arg else {
spanned_error!("pg-worm: unsupported Option/Vec generic argument, must be valid type", &ty)
};
let syn::Type::Path(type_path) = arg else {
panic!("generic arg for Vec must be path")

let syn::Type::Path(path) = &arg_type else {
spanned_error!("pg-worm: unsupported type, must be a TypePath", &ty)
};

let ident = &type_path
.path
.segments
.first()
.expect("generic arg for Vec must have segment")
.ident;
let Some(segment) = path.path.segments.last() else {
spanned_error!("pg-worm: unsupported type path, must have at least one segment", &ty)
};

return from_type(ident);
id = &segment.ident;
}

from_type(&segment.ident)
Ok(match id.to_string().as_ref() {
"String" => Type::TEXT,
"i32" => Type::INT4,
"i64" => Type::INT8,
"f32" => Type::FLOAT4,
"f64" => Type::FLOAT8,
"bool" => Type::BOOL,
_ => spanned_error!("pg-worm: unsupported type, check docs", &ty),
})
}

/// Get the SQL representing the column needed
/// for creating a table.
fn column_creation_sql(&self) -> String {
fn column_creation_sql(&self) -> Result<String, Error> {
// The list of "args" for the sql statement.
// Includes at least the column name and datatype.
let mut args = vec![self.column_name(), self.pg_datatype().to_string()];
let mut args = vec![self.column_name(), self.try_pg_datatype()?.to_string()];

// This macro allows adding an arg to the list
// under a given condition.
Expand All @@ -469,7 +438,7 @@ impl ModelField {
arg!(!(self.primary_key || self.nullable), "NOT NULL");

// Join the args, seperated by a space and return them
args.join(" ")
Ok(args.join(" "))
}

/// The datatype which should be provided when
Expand Down
1 change: 1 addition & 0 deletions pg-worm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ thiserror = "1.0"
deadpool-postgres = "0.10"
tokio-postgres = "0.7"
async-trait = "0.1"
futures = "0.3"

pg-worm-derive = { version = "0.5", path = "../pg-worm-derive" }

Expand Down
Loading

0 comments on commit 1e6ded3

Please sign in to comment.