Skip to content

Commit

Permalink
Support update ... from (Pg | SQLite) (#2963)
Browse files Browse the repository at this point in the history
* Implement `update ... from` in PG

* Add `update ... from` in SQLite

* Lint and format

* Fix type error

* Fix SQLite type errors

* Lint and format

* Push merge changes
  • Loading branch information
L-Mario564 authored Nov 14, 2024
1 parent c31614a commit d7e3535
Show file tree
Hide file tree
Showing 13 changed files with 1,152 additions and 137 deletions.
134 changes: 78 additions & 56 deletions drizzle-orm/src/pg-core/dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,30 @@ export class PgDialect {
}));
}

buildUpdateQuery({ table, set, where, returning, withList }: PgUpdateConfig): SQL {
buildUpdateQuery({ table, set, where, returning, withList, from, joins }: PgUpdateConfig): SQL {
const withSql = this.buildWithCTE(withList);

const tableName = table[PgTable.Symbol.Name];
const tableSchema = table[PgTable.Symbol.Schema];
const origTableName = table[PgTable.Symbol.OriginalName];
const alias = tableName === origTableName ? undefined : tableName;
const tableSql = sql`${tableSchema ? sql`${sql.identifier(tableSchema)}.` : undefined}${
sql.identifier(origTableName)
}${alias && sql` ${sql.identifier(alias)}`}`;

const setSql = this.buildUpdateSet(table, set);

const fromSql = from && sql.join([sql.raw(' from '), this.buildFromTable(from)]);

const joinsSql = this.buildJoins(joins);

const returningSql = returning
? sql` returning ${this.buildSelection(returning, { isSingleTable: true })}`
? sql` returning ${this.buildSelection(returning, { isSingleTable: !from })}`
: undefined;

const whereSql = where ? sql` where ${where}` : undefined;

return sql`${withSql}update ${table} set ${setSql}${whereSql}${returningSql}`;
return sql`${withSql}update ${tableSql} set ${setSql}${fromSql}${joinsSql}${whereSql}${returningSql}`;
}

/**
Expand Down Expand Up @@ -245,6 +257,67 @@ export class PgDialect {
return sql.join(chunks);
}

private buildJoins(joins: PgSelectJoinConfig[] | undefined): SQL | undefined {
if (!joins || joins.length === 0) {
return undefined;
}

const joinsArray: SQL[] = [];

for (const [index, joinMeta] of joins.entries()) {
if (index === 0) {
joinsArray.push(sql` `);
}
const table = joinMeta.table;
const lateralSql = joinMeta.lateral ? sql` lateral` : undefined;

if (is(table, PgTable)) {
const tableName = table[PgTable.Symbol.Name];
const tableSchema = table[PgTable.Symbol.Schema];
const origTableName = table[PgTable.Symbol.OriginalName];
const alias = tableName === origTableName ? undefined : joinMeta.alias;
joinsArray.push(
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
tableSchema ? sql`${sql.identifier(tableSchema)}.` : undefined
}${sql.identifier(origTableName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
);
} else if (is(table, View)) {
const viewName = table[ViewBaseConfig].name;
const viewSchema = table[ViewBaseConfig].schema;
const origViewName = table[ViewBaseConfig].originalName;
const alias = viewName === origViewName ? undefined : joinMeta.alias;
joinsArray.push(
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
viewSchema ? sql`${sql.identifier(viewSchema)}.` : undefined
}${sql.identifier(origViewName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
);
} else {
joinsArray.push(
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${table} on ${joinMeta.on}`,
);
}
if (index < joins.length - 1) {
joinsArray.push(sql` `);
}
}

return sql.join(joinsArray);
}

private buildFromTable(
table: SQL | Subquery | PgViewBase | PgTable | undefined,
): SQL | Subquery | PgViewBase | PgTable | undefined {
if (is(table, Table) && table[Table.Symbol.OriginalName] !== table[Table.Symbol.Name]) {
let fullName = sql`${sql.identifier(table[Table.Symbol.OriginalName])}`;
if (table[Table.Symbol.Schema]) {
fullName = sql`${sql.identifier(table[Table.Symbol.Schema]!)}.${fullName}`;
}
return sql`${fullName} ${sql.identifier(table[Table.Symbol.Name])}`;
}

return table;
}

buildSelectQuery(
{
withList,
Expand Down Expand Up @@ -300,60 +373,9 @@ export class PgDialect {

const selection = this.buildSelection(fieldsList, { isSingleTable });

const tableSql = (() => {
if (is(table, Table) && table[Table.Symbol.OriginalName] !== table[Table.Symbol.Name]) {
let fullName = sql`${sql.identifier(table[Table.Symbol.OriginalName])}`;
if (table[Table.Symbol.Schema]) {
fullName = sql`${sql.identifier(table[Table.Symbol.Schema]!)}.${fullName}`;
}
return sql`${fullName} ${sql.identifier(table[Table.Symbol.Name])}`;
}

return table;
})();

const joinsArray: SQL[] = [];

if (joins) {
for (const [index, joinMeta] of joins.entries()) {
if (index === 0) {
joinsArray.push(sql` `);
}
const table = joinMeta.table;
const lateralSql = joinMeta.lateral ? sql` lateral` : undefined;

if (is(table, PgTable)) {
const tableName = table[PgTable.Symbol.Name];
const tableSchema = table[PgTable.Symbol.Schema];
const origTableName = table[PgTable.Symbol.OriginalName];
const alias = tableName === origTableName ? undefined : joinMeta.alias;
joinsArray.push(
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
tableSchema ? sql`${sql.identifier(tableSchema)}.` : undefined
}${sql.identifier(origTableName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
);
} else if (is(table, View)) {
const viewName = table[ViewBaseConfig].name;
const viewSchema = table[ViewBaseConfig].schema;
const origViewName = table[ViewBaseConfig].originalName;
const alias = viewName === origViewName ? undefined : joinMeta.alias;
joinsArray.push(
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${
viewSchema ? sql`${sql.identifier(viewSchema)}.` : undefined
}${sql.identifier(origViewName)}${alias && sql` ${sql.identifier(alias)}`} on ${joinMeta.on}`,
);
} else {
joinsArray.push(
sql`${sql.raw(joinMeta.joinType)} join${lateralSql} ${table} on ${joinMeta.on}`,
);
}
if (index < joins.length - 1) {
joinsArray.push(sql` `);
}
}
}
const tableSql = this.buildFromTable(table);

const joinsSql = sql.join(joinsArray);
const joinsSql = this.buildJoins(joins);

const whereSql = where ? sql` where ${where}` : undefined;

Expand Down
4 changes: 2 additions & 2 deletions drizzle-orm/src/pg-core/query-builders/select.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ import type {
LockConfig,
LockStrength,
PgCreateSetOperatorFn,
PgJoinFn,
PgSelectConfig,
PgSelectDynamic,
PgSelectHKT,
PgSelectHKTBase,
PgSelectJoinFn,
PgSelectPrepare,
PgSelectWithout,
PgSetOperatorExcludedMethods,
Expand Down Expand Up @@ -194,7 +194,7 @@ export abstract class PgSelectQueryBuilderBase<

private createJoin<TJoinType extends JoinType>(
joinType: TJoinType,
): PgJoinFn<this, TDynamic, TJoinType> {
): PgSelectJoinFn<this, TDynamic, TJoinType> {
return (
table: PgTable | Subquery | PgViewBase | SQL,
on: ((aliases: TSelection) => SQL | undefined) | SQL | undefined,
Expand Down
6 changes: 3 additions & 3 deletions drizzle-orm/src/pg-core/query-builders/select.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export interface PgSelectConfig {
}[];
}

export type PgJoin<
export type PgSelectJoin<
T extends AnyPgSelectQueryBuilder,
TDynamic extends boolean,
TJoinType extends JoinType,
Expand Down Expand Up @@ -108,7 +108,7 @@ export type PgJoin<
>
: never;

export type PgJoinFn<
export type PgSelectJoinFn<
T extends AnyPgSelectQueryBuilder,
TDynamic extends boolean,
TJoinType extends JoinType,
Expand All @@ -118,7 +118,7 @@ export type PgJoinFn<
>(
table: TJoinedTable,
on: ((aliases: T['_']['selection']) => SQL | undefined) | SQL | undefined,
) => PgJoin<T, TDynamic, TJoinType, TJoinedTable, TJoinedName>;
) => PgSelectJoin<T, TDynamic, TJoinType, TJoinedTable, TJoinedName>;

export type SelectedFieldsFlat = SelectedFieldsFlatBase<PgColumn>;

Expand Down
Loading

0 comments on commit d7e3535

Please sign in to comment.