Skip to content

Commit

Permalink
Support pk in introspection
Browse files Browse the repository at this point in the history
  • Loading branch information
aqrln committed Nov 27, 2024
1 parent ce2457c commit c1f4599
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl<'a> DatamodelCalculatorContext<'a> {
.table_walkers()
.filter(|table| !is_old_migration_table(*table))
.filter(|table| !is_new_migration_table(*table))
.filter(|table| !is_prisma_m_to_n_relation(*table))
.filter(|table| !is_prisma_m_to_n_relation(*table, self.flavour.uses_pk_in_m2m_join_tables(self)))
.filter(|table| !is_relay_table(*table))
.map(move |next| {
let previous = self.existing_model(next.id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,8 @@ pub(crate) trait IntrospectionFlavour {
fn uses_exclude_constraint(&self, _ctx: &DatamodelCalculatorContext<'_>, _table: TableWalker<'_>) -> bool {
false
}

fn uses_pk_in_m2m_join_tables(&self, _ctx: &DatamodelCalculatorContext<'_>) -> bool {
false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,8 @@ impl super::IntrospectionFlavour for PostgresIntrospectionFlavour {
let pg_ext: &PostgresSchemaExt = ctx.sql_schema.downcast_connector_data();
pg_ext.uses_exclude_constraint(table.id)
}

fn uses_pk_in_m2m_join_tables(&self, ctx: &DatamodelCalculatorContext<'_>) -> bool {
!ctx.is_cockroach()
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Small utility functions.
use sql::walkers::TableWalker;
use sql_schema_describer::{self as sql, IndexType};
use sql_schema_describer::{self as sql, IndexColumnWalker, IndexType};
use std::cmp;

/// This function implements the reverse behaviour of the `Ord` implementation for `Option`: it
Expand Down Expand Up @@ -56,7 +56,7 @@ pub(crate) fn is_relay_table(table: TableWalker<'_>) -> bool {
}

/// If a relation defines a Prisma many to many relation.
pub(crate) fn is_prisma_m_to_n_relation(table: TableWalker<'_>) -> bool {
pub(crate) fn is_prisma_m_to_n_relation(table: TableWalker<'_>, pk_allowed: bool) -> bool {
fn is_a(column: &str) -> bool {
column.eq_ignore_ascii_case("a")
}
Expand All @@ -65,9 +65,16 @@ pub(crate) fn is_prisma_m_to_n_relation(table: TableWalker<'_>) -> bool {
column.eq_ignore_ascii_case("b")
}

fn index_columns_match<'a>(mut columns: impl ExactSizeIterator<Item = IndexColumnWalker<'a>>) -> bool {
columns.len() == 2
&& is_a(columns.next().unwrap().as_column().name())
&& is_b(columns.next().unwrap().as_column().name())
}

let mut fks = table.foreign_keys();
let first_fk = fks.next();
let second_fk = fks.next();

let a_b_match = || {
let first_fk = first_fk.unwrap();
let second_fk = second_fk.unwrap();
Expand All @@ -80,14 +87,13 @@ pub(crate) fn is_prisma_m_to_n_relation(table: TableWalker<'_>) -> bool {
&& is_b(first_fk_col)
&& is_a(second_fk_col))
};

table.name().starts_with('_')
//UNIQUE INDEX [A,B]
&& table.indexes().any(|i| {
i.columns().len() == 2
&& is_a(i.columns().next().unwrap().as_column().name())
&& is_b(i.columns().nth(1).unwrap().as_column().name())
// UNIQUE INDEX (A, B) or PRIMARY KEY (A, B)
&& (table.indexes().any(|i| {
index_columns_match(i.columns())
&& i.is_unique()
})
}) || pk_allowed && table.primary_key_columns().map(index_columns_match).unwrap_or(false))
//INDEX [B]
&& table
.indexes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ impl<'a> IntrospectionMap<'a> {
match_enums(sql_schema, prisma_schema, &mut map);
match_existing_scalar_fields(sql_schema, prisma_schema, &mut map);
match_existing_inline_relations(sql_schema, prisma_schema, &mut map);
match_existing_m2m_relations(sql_schema, prisma_schema, &mut map);
match_existing_m2m_relations(sql_schema, prisma_schema, ctx, &mut map);
relation_names::introspect(ctx, &mut map);
position_inline_relation_fields(sql_schema, &mut map);
position_m2m_relation_fields(sql_schema, &mut map);
populate_top_level_names(sql_schema, prisma_schema, &mut map);
position_inline_relation_fields(sql_schema, ctx, &mut map);
position_m2m_relation_fields(sql_schema, ctx, &mut map);
populate_top_level_names(sql_schema, prisma_schema, ctx, &mut map);

map
}
Expand All @@ -63,11 +63,12 @@ impl<'a> IntrospectionMap<'a> {
fn populate_top_level_names<'a>(
sql_schema: &'a sql::SqlSchema,
prisma_schema: &'a psl::ValidatedSchema,
ctx: &DatamodelCalculatorContext<'_>,
map: &mut IntrospectionMap<'a>,
) {
for table in sql_schema
.table_walkers()
.filter(|t| !helpers::is_prisma_m_to_n_relation(*t))
.filter(|t| !helpers::is_prisma_m_to_n_relation(*t, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)))
{
let name = map
.existing_models
Expand Down Expand Up @@ -115,10 +116,14 @@ fn populate_top_level_names<'a>(

/// Inlined relation fields (foreign key is defined in a model) are
/// sorted in a specific way. We handle the sorting here.
fn position_inline_relation_fields(sql_schema: &sql::SqlSchema, map: &mut IntrospectionMap<'_>) {
fn position_inline_relation_fields(
sql_schema: &sql::SqlSchema,
ctx: &DatamodelCalculatorContext<'_>,
map: &mut IntrospectionMap<'_>,
) {
for table in sql_schema
.table_walkers()
.filter(|t| !helpers::is_prisma_m_to_n_relation(*t))
.filter(|t| !helpers::is_prisma_m_to_n_relation(*t, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)))
{
for fk in table.foreign_keys() {
map.inline_relation_positions
Expand All @@ -133,10 +138,14 @@ fn position_inline_relation_fields(sql_schema: &sql::SqlSchema, map: &mut Intros
/// Many to many relation fields (foreign keys are defined in a hidden
/// join table) are sorted in a specific way. We handle the sorting
/// here.
fn position_m2m_relation_fields(sql_schema: &sql::SqlSchema, map: &mut IntrospectionMap<'_>) {
fn position_m2m_relation_fields(
sql_schema: &sql::SqlSchema,
ctx: &DatamodelCalculatorContext<'_>,
map: &mut IntrospectionMap<'_>,
) {
for table in sql_schema
.table_walkers()
.filter(|t| helpers::is_prisma_m_to_n_relation(*t))
.filter(|t| helpers::is_prisma_m_to_n_relation(*t, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)))
{
let mut fks = table.foreign_keys();

Expand Down Expand Up @@ -313,11 +322,12 @@ fn match_existing_inline_relations<'a>(
fn match_existing_m2m_relations(
sql_schema: &sql::SqlSchema,
prisma_schema: &psl::ValidatedSchema,
ctx: &DatamodelCalculatorContext<'_>,
map: &mut IntrospectionMap<'_>,
) {
map.existing_m2m_relations = sql_schema
.table_walkers()
.filter(|t| helpers::is_prisma_m_to_n_relation(*t))
.filter(|t| helpers::is_prisma_m_to_n_relation(*t, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)))
.filter_map(|table| {
prisma_schema
.db
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub(super) fn introspect<'a>(ctx: &DatamodelCalculatorContext<'a>, map: &mut sup
let ambiguous_relations = find_ambiguous_relations(ctx);

for table in ctx.sql_schema.table_walkers() {
if is_prisma_m_to_n_relation(table) {
if is_prisma_m_to_n_relation(table, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)) {
let name = prisma_m2m_relation_name(table, &ambiguous_relations, ctx);
names.m2m_relation_names.insert(table.id, name);
} else {
Expand Down Expand Up @@ -175,8 +175,8 @@ fn find_ambiguous_relations(ctx: &DatamodelCalculatorContext<'_>) -> HashSet<[sq
let mut ambiguous_relations = HashSet::new();

for table in ctx.sql_schema.table_walkers() {
if is_prisma_m_to_n_relation(table) {
m2m_relation_ambiguousness(table, &mut ambiguous_relations)
if is_prisma_m_to_n_relation(table, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)) {
m2m_relation_ambiguousness(table, ctx, &mut ambiguous_relations)
} else {
for fk in table.foreign_keys() {
inline_relation_ambiguousness(fk, &mut ambiguous_relations, ctx)
Expand All @@ -187,7 +187,11 @@ fn find_ambiguous_relations(ctx: &DatamodelCalculatorContext<'_>) -> HashSet<[sq
ambiguous_relations
}

fn m2m_relation_ambiguousness(table: sql::TableWalker<'_>, ambiguous_relations: &mut HashSet<[sql::TableId; 2]>) {
fn m2m_relation_ambiguousness(
table: sql::TableWalker<'_>,
ctx: &DatamodelCalculatorContext<'_>,
ambiguous_relations: &mut HashSet<[sql::TableId; 2]>,
) {
let tables = table_ids_for_m2m_relation_table(table);

if ambiguous_relations.contains(&tables) {
Expand All @@ -205,7 +209,11 @@ fn m2m_relation_ambiguousness(table: sql::TableWalker<'_>, ambiguous_relations:
}

// Check for conflicts with another m2m relation.
for other_m2m in table.schema.table_walkers().filter(|t| is_prisma_m_to_n_relation(*t)) {
for other_m2m in table
.schema
.table_walkers()
.filter(|t| is_prisma_m_to_n_relation(*t, ctx.flavour.uses_pk_in_m2m_join_tables(ctx)))
{
if other_m2m.id != table.id && table_ids_for_m2m_relation_table(other_m2m) == tables {
ambiguous_relations.insert(tables);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ async fn name_ambiguity_with_a_scalar_field(api: &mut TestApi) -> TestResult {
}

#[test_connector(tags(Postgres), exclude(CockroachDb))]
async fn a_prisma_many_to_many_relation(api: &mut TestApi) -> TestResult {
async fn legacy_prisma_many_to_many_relation(api: &mut TestApi) -> TestResult {
let setup = indoc! {r#"
CREATE TABLE "User" (
id SERIAL PRIMARY KEY
Expand Down Expand Up @@ -266,3 +266,53 @@ async fn a_prisma_many_to_many_relation(api: &mut TestApi) -> TestResult {

Ok(())
}

#[test_connector(tags(Postgres), exclude(CockroachDb))]
async fn new_prisma_many_to_many_relation(api: &mut TestApi) -> TestResult {
let setup = indoc! {r#"
CREATE TABLE "User" (
id SERIAL PRIMARY KEY
);
CREATE TABLE "Post" (
id SERIAL PRIMARY KEY
);
CREATE TABLE "_PostToUser" (
"A" INT NOT NULL,
"B" INT NOT NULL,
CONSTRAINT "_PostToUser_A_fkey" FOREIGN KEY ("A") REFERENCES "Post"(id),
CONSTRAINT "_PostToUser_B_fkey" FOREIGN KEY ("B") REFERENCES "User"(id),
CONSTRAINT "_PostToUser_AB_pkey" PRIMARY KEY ("A", "B")
);
CREATE INDEX test ON "_PostToUser" ("B");
"#};

api.raw_cmd(setup).await;

let expected = expect![[r#"
generator client {
provider = "prisma-client-js"
}
datasource db {
provider = "postgresql"
url = "env(TEST_DATABASE_URL)"
}
model Post {
id Int @id @default(autoincrement())
User User[]
}
model User {
id Int @id @default(autoincrement())
Post Post[]
}
"#]];

api.expect_datamodel(&expected).await;

Ok(())
}

0 comments on commit c1f4599

Please sign in to comment.