Skip to content

Commit

Permalink
consistent model filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
aliasaria committed Jan 25, 2024
1 parent 086a319 commit bd4e28a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 92 deletions.
50 changes: 32 additions & 18 deletions src/renderer/components/Experiment/Foundation/SelectAModel.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
/* eslint-disable jsx-a11y/anchor-is-valid */
import { useCallback, useEffect, useState } from 'react';
import { useState } from 'react';

import {
Button,
Checkbox,
FormControl,
FormLabel,
Input,
Expand All @@ -26,22 +24,18 @@ import { ColorPaletteProp } from '@mui/joy/styles';

import {
ArrowDownIcon,
BoxesIcon,
CheckIcon,
CreativeCommonsIcon,
FolderOpenIcon,
GraduationCapIcon,
InfoIcon,
PlusIcon,
SearchIcon,
StoreIcon,
Trash2Icon,
} from 'lucide-react';
import SelectButton from '../SelectButton';
import CurrentFoundationInfo from './CurrentFoundationInfo';
import useSWR from 'swr';
import * as chatAPI from '../../../lib/transformerlab-api-sdk';
import Welcome from '../../Welcome';

import { modelTypes, licenseTypes, filterByFilters } from '../../../lib/utils';

type Order = 'asc' | 'desc';

Expand Down Expand Up @@ -74,6 +68,8 @@ export default function SelectAModel({
}) {
const [order, setOrder] = useState<Order>('desc');
const [open, setOpen] = useState(false);
const [searchText, setSearchText] = useState('');
const [filters, setFilters] = useState({});

const { data, error, isLoading, mutate } = useSWR(
chatAPI.Endpoints.Models.LocalList(),
Expand All @@ -99,17 +95,30 @@ export default function SelectAModel({
<Select
placeholder="Filter by license"
slotProps={{ button: { sx: { whiteSpace: 'nowrap' } } }}
value={filters?.license}
disabled
onChange={(e, newValue) => {
setFilters({ ...filters, license: newValue });
}}
>
<Option value="MIT">MIT</Option>
<Option value="pending">CC BY-SA-4.0</Option>
<Option value="refunded">Refunded</Option>
<Option value="Cancelled">Apache 2.0</Option>
{licenseTypes.map((type) => (
<Option value={type}>{type}</Option>
))}
</Select>
</FormControl>
<FormControl size="sm">
<FormLabel>Category</FormLabel>
<Select placeholder="All">
<Option value="all">All</Option>
<FormLabel>Architecture</FormLabel>
<Select
placeholder="All"
disabled
value={filters?.architecture}
onChange={(e, newValue) => {
setFilters({ ...filters, architecture: newValue });
}}
>
{modelTypes.map((type) => (
<Option value={type}>{type}</Option>
))}
</Select>
</FormControl>
</>
Expand Down Expand Up @@ -194,7 +203,12 @@ export default function SelectAModel({
>
<FormControl sx={{ flex: 1 }} size="sm">
<FormLabel>&nbsp;</FormLabel>
<Input placeholder="Search" startDecorator={<SearchIcon />} />
<Input
placeholder="Search"
value={searchText}
onChange={(e) => setSearchText(e.target.value)}
startDecorator={<SearchIcon />}
/>
</FormControl>

{renderFilters()}
Expand Down Expand Up @@ -252,7 +266,7 @@ export default function SelectAModel({
</thead>
<tbody>
{data &&
data.map((row) => (
filterByFilters(data, searchText, filters).map((row) => (
<tr key={row.rowid}>
<td>
<Typography ml={2} fontWeight="lg">
Expand Down
39 changes: 2 additions & 37 deletions src/renderer/components/ModelZoo/LocalModels.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,44 +43,9 @@ import useSWR from 'swr';
import * as chatAPI from '../../lib/transformerlab-api-sdk';
import Welcome from '../Welcome';

type Order = 'asc' | 'desc';

const modelTypes = [
'All',
'MLX',
'GGUF',
'LlamaForCausalLM',
'MistralForCausalLM',
'T5ForConditionalGeneration',
'PhiForCausalLM',
'GPTBigCodeForCausalLM',
];

const licenseTypes = [
'All',
'MIT',
'CC BY-SA-4.0',
'Apache 2.0',
'Meta Custom',
'GPL',
];
import { modelTypes, licenseTypes, filterByFilters } from '../../lib/utils';

function filterByFilters(data, searchText = '', filters = {}) {
return data.filter((row) => {
if (row.name.toLowerCase().includes(searchText.toLowerCase())) {
for (const filterKey in filters) {
console.log(filterKey, filters[filterKey]);
if (filters[filterKey] !== 'All') {
if (row[filterKey] !== filters[filterKey]) {
return false;
}
}
}
return true;
}
return false;
});
}
type Order = 'asc' | 'desc';

const fetcher = (url) => fetch(url).then((res) => res.json());

Expand Down
39 changes: 2 additions & 37 deletions src/renderer/components/ModelZoo/ModelStore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import useSWR from 'swr';
import * as chatAPI from '../../lib/transformerlab-api-sdk';
import TinyMLXLogo from '../Shared/TinyMLXLogo';

import { modelTypes, licenseTypes, filterByFilters } from '../../lib/utils';

function descendingComparator<T>(a: T, b: T, orderBy: keyof T) {
if (b[orderBy] < a[orderBy]) {
return -1;
Expand Down Expand Up @@ -74,43 +76,6 @@ function stableSort<T>(
return stabilizedThis.map((el) => el[0]);
}

function filterByFilters(data, searchText = '', filters = {}) {
return data.filter((row) => {
if (row.name.toLowerCase().includes(searchText.toLowerCase())) {
for (const filterKey in filters) {
console.log(filterKey, filters[filterKey]);
if (filters[filterKey] !== 'All') {
if (row[filterKey] !== filters[filterKey]) {
return false;
}
}
}
return true;
}
return false;
});
}

const modelTypes = [
'All',
'MLX',
'GGUF',
'LlamaForCausalLM',
'MistralForCausalLM',
'T5ForConditionalGeneration',
'PhiForCausalLM',
'GPTBigCodeForCausalLM',
];

const licenseTypes = [
'All',
'MIT',
'CC BY-SA-4.0',
'Apache 2.0',
'Meta Custom',
'GPL',
];

const fetcher = (url) => fetch(url).then((res) => res.json());

export default function ModelStore() {
Expand Down
37 changes: 37 additions & 0 deletions src/renderer/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,40 @@ export function formatBytes(bytes: number, decimals = 2): string {

return `${parseFloat((bytes / k ** i).toFixed(dm))} ${sizes[i]}`;
}

export const modelTypes = [
'All',
'MLX',
'GGUF',
'LlamaForCausalLM',
'MistralForCausalLM',
'T5ForConditionalGeneration',
'PhiForCausalLM',
'GPTBigCodeForCausalLM',
];

export const licenseTypes = [
'All',
'MIT',
'CC BY-SA-4.0',
'Apache 2.0',
'Meta Custom',
'GPL',
];

export function filterByFilters(data, searchText = '', filters = {}) {
return data.filter((row) => {
if (row.name.toLowerCase().includes(searchText.toLowerCase())) {
for (const filterKey in filters) {
console.log(filterKey, filters[filterKey]);
if (filters[filterKey] !== 'All') {
if (row[filterKey] !== filters[filterKey]) {
return false;
}
}
}
return true;
}
return false;
});
}

0 comments on commit bd4e28a

Please sign in to comment.