Skip to content

Commit

Permalink
feat: Improve auto complete (#183)
Browse files Browse the repository at this point in the history
* improve auto complete

* escape id if it is conflicted with keywords

* feat: improve the auto complete in the middle of the statement
  • Loading branch information
invisal authored Oct 30, 2023
1 parent 7eee130 commit 6f5c924
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 175 deletions.
66 changes: 58 additions & 8 deletions src/renderer/components/CodeEditor/SchemaCompletionTree.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import { Completion } from '@codemirror/autocomplete';
import { SQLDialectSpec } from 'language/dist';
import {
DatabaseSchema,
DatabaseSchemaList,
TableSchema,
} from 'types/SqlSchema';

function buildTableCompletionTree(table: TableSchema): SchemaCompletionTree {
function buildTableCompletionTree(
table: TableSchema,
dialect: SQLDialectSpec,
keywords: Record<string, boolean>,
): SchemaCompletionTree {
const root = new SchemaCompletionTree();

for (const col of Object.values(table.columns)) {
root.addOption(col.name, {
label: col.name,
apply: escapeConflictedId(dialect, col.name, keywords),
type: 'property',
detail: col.dataType,
boost: 3,
Expand All @@ -21,7 +27,9 @@ function buildTableCompletionTree(table: TableSchema): SchemaCompletionTree {
}

function buildDatabaseCompletionTree(
database: DatabaseSchema
database: DatabaseSchema,
dialect: SQLDialectSpec,
keywords: Record<string, boolean>,
): SchemaCompletionTree {
const root = new SchemaCompletionTree();

Expand All @@ -33,15 +41,30 @@ function buildDatabaseCompletionTree(
boost: 1,
});

root.addChild(table.name, buildTableCompletionTree(table));
root.addChild(
table.name,
buildTableCompletionTree(table, dialect, keywords),
);
}

return root;
}

function escapeConflictedId(
dialect: SQLDialectSpec,
label: string,
keywords: Record<string, boolean>,
): string {
if (keywords[label.toLowerCase()])
return `${dialect.identifierQuotes}${label}${dialect.identifierQuotes}`;
return label;
}

function buildCompletionTree(
schema: DatabaseSchemaList | undefined,
currentDatabase: string | undefined
currentDatabase: string | undefined,
dialect: SQLDialectSpec,
keywords: Record<string, boolean>,
): SchemaCompletionTree {
const root: SchemaCompletionTree = new SchemaCompletionTree();
if (!schema) return root;
Expand All @@ -51,36 +74,63 @@ function buildCompletionTree(
for (const table of Object.values(schema[currentDatabase].tables)) {
root.addOption(table.name, {
label: table.name,
apply: escapeConflictedId(dialect, table.name, keywords),
type: 'table',
detail: 'table',
boost: 1,
});

root.addChild(table.name, buildTableCompletionTree(table));
root.addChild(
table.name,
buildTableCompletionTree(table, dialect, keywords),
);
}
}

for (const database of Object.values(schema)) {
root.addOption(database.name, {
label: database.name,
apply: escapeConflictedId(dialect, database.name, keywords),
type: 'property',
detail: 'database',
});

root.addChild(database.name, buildDatabaseCompletionTree(database));
root.addChild(
database.name,
buildDatabaseCompletionTree(database, dialect, keywords),
);
}

return root;
}
export default class SchemaCompletionTree {
protected options: Record<string, Completion> = {};
protected child: Record<string, SchemaCompletionTree> = {};
protected keywords: Record<string, boolean> = {};

static build(
schema: DatabaseSchemaList | undefined,
currentDatabase: string | undefined
currentDatabase: string | undefined,
dialect: SQLDialectSpec,
) {
return buildCompletionTree(schema, currentDatabase);
const keywords = (dialect.keywords + ' ' + dialect.builtin)
.split(' ')
.filter(Boolean)
.map((s) => s.toLowerCase());

const keywordDict = keywords.reduce(
(a, keyword) => {
a[keyword] = true;
return a;
},
{} as Record<string, boolean>,
);

return buildCompletionTree(schema, currentDatabase, dialect, keywordDict);
}

getLength() {
return Object.keys(this.options).length;
}

addOption(name: string, complete: Completion) {
Expand Down
18 changes: 14 additions & 4 deletions src/renderer/components/CodeEditor/SqlCodeEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { MySQLDialect, MySQLTooltips } from 'dialects/MySQLDialect';
import { QueryDialetType } from 'libs/QueryBuilder';
import { PgDialect, PgTooltips } from 'dialects/PgDialect copy';
import { useKeybinding } from 'renderer/contexts/KeyBindingProvider';
import SchemaCompletionTree from './SchemaCompletionTree';

const SqlCodeEditor = forwardRef(function SqlCodeEditor(
props: ReactCodeMirrorProps & {
Expand All @@ -43,16 +44,28 @@ const SqlCodeEditor = forwardRef(function SqlCodeEditor(
const { binding } = useKeybinding();
const theme = useCodeEditorTheme();

const dialect = props.dialect === 'mysql' ? MySQLDialect : PgDialect;
const tooltips = props.dialect === 'mysql' ? MySQLTooltips : PgTooltips;

const schemaTree = useMemo(() => {
return SchemaCompletionTree.build(
schema?.getSchema(),
currentDatabase,
dialect.spec,
);
}, [schema, currentDatabase, dialect]);

const customAutoComplete = useCallback(
(context: CompletionContext, tree: SyntaxNode): CompletionResult | null => {
return handleCustomSqlAutoComplete(
context,
tree,
schemaTree,
schema?.getSchema(),
currentDatabase,
);
},
[schema, currentDatabase],
[schema, schemaTree, currentDatabase],
);

const tableNameHighlightPlugin = useMemo(() => {
Expand All @@ -64,9 +77,6 @@ const SqlCodeEditor = forwardRef(function SqlCodeEditor(
return createSQLTableNameHighlightPlugin([]);
}, [schema, currentDatabase]);

const dialect = props.dialect === 'mysql' ? MySQLDialect : PgDialect;
const tooltips = props.dialect === 'mysql' ? MySQLTooltips : PgTooltips;

const keyExtension = useMemo(() => {
return keymap.of([
// Prevent the default behavior if it matches any of
Expand Down
140 changes: 140 additions & 0 deletions src/renderer/components/CodeEditor/autocomplete_test_utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import { EditorState } from '@codemirror/state';
import {
CompletionContext,
CompletionResult,
CompletionSource,
} from '@codemirror/autocomplete';
import handleCustomSqlAutoComplete from './handleCustomSqlAutoComplete';
import { MySQL, genericCompletion } from './../../../language/dist';
import {
DatabaseSchemaList,
TableColumnSchema,
TableSchema,
} from 'types/SqlSchema';
import SchemaCompletionTree from './SchemaCompletionTree';

export function get_test_autocomplete(
doc: string,
{
schema,
currentDatabase,
}: { schema: DatabaseSchemaList; currentDatabase?: string },
) {
const cur = doc.indexOf('|'),
dialect = MySQL;

doc = doc.slice(0, cur) + doc.slice(cur + 1);
const state = EditorState.create({
doc,
selection: { anchor: cur },
extensions: [
dialect,
dialect.language.data.of({
autocomplete: genericCompletion((context, tree) =>
handleCustomSqlAutoComplete(
context,
tree,
SchemaCompletionTree.build(schema, currentDatabase, dialect.spec),
schema,
currentDatabase,
),
),
}),
],
});

const result = state.languageDataAt<CompletionSource>('autocomplete', cur)[0](
new CompletionContext(state, cur, false),
);
return result as CompletionResult | null;
}

export function convert_autocomplete_to_string(
result: CompletionResult | null,
) {
return !result
? ''
: result.options
.slice()
.sort(
(a, b) =>
(b.boost || 0) - (a.boost || 0) || (a.label < b.label ? -1 : 1),
)
.map((o) => o.apply || o.label)
.join(', ');
}

function map_column_type(
tableName: string,
name: string,
type: string,
): TableColumnSchema {
const tokens = type.split('(');
let enumValues: string[] | undefined;

if (tokens[1]) {
// remove )
enumValues = tokens[1]
.replace(')', '')
.replaceAll("'", '')
.split(',')
.map((a) => a.trim());
}

return {
name,
tableName,
schemaName: '',
charLength: 0,
comment: '',
enumValues,
dataType: tokens[0],
nullable: true,
};
}

function map_cols(
tableName: string,
cols: Record<string, string>,
): Record<string, TableColumnSchema> {
return Object.entries(cols).reduce(
(acc, [colName, colType]) => {
acc[colName] = map_column_type(tableName, colName, colType);
return acc;
},
{} as Record<string, TableColumnSchema>,
);
}

function map_table(
tables: Record<string, Record<string, string>>,
): Record<string, TableSchema> {
return Object.entries(tables).reduce(
(acc, [tableName, cols]) => {
acc[tableName] = {
columns: map_cols(tableName, cols),
constraints: [],
name: tableName,
type: 'TABLE',
primaryKey: [],
};

return acc;
},
{} as Record<string, TableSchema>,
);
}

export function create_test_schema(
schemas: Record<string, Record<string, Record<string, string>>>,
) {
return Object.entries(schemas).reduce((acc, [schema, tables]) => {
acc[schema] = {
name: schema,
events: [],
triggers: [],
tables: map_table(tables),
};
return acc;
}, {} as DatabaseSchemaList);
}
Loading

0 comments on commit 6f5c924

Please sign in to comment.