diff --git a/src/renderer/components/Experiment/Foundation/SelectAModel.tsx b/src/renderer/components/Experiment/Foundation/SelectAModel.tsx index bd6abab..8cc21f2 100644 --- a/src/renderer/components/Experiment/Foundation/SelectAModel.tsx +++ b/src/renderer/components/Experiment/Foundation/SelectAModel.tsx @@ -1,9 +1,7 @@ /* eslint-disable jsx-a11y/anchor-is-valid */ -import { useCallback, useEffect, useState } from 'react'; +import { useState } from 'react'; import { - Button, - Checkbox, FormControl, FormLabel, Input, @@ -26,22 +24,18 @@ import { ColorPaletteProp } from '@mui/joy/styles'; import { ArrowDownIcon, - BoxesIcon, CheckIcon, CreativeCommonsIcon, - FolderOpenIcon, GraduationCapIcon, - InfoIcon, - PlusIcon, SearchIcon, StoreIcon, - Trash2Icon, } from 'lucide-react'; import SelectButton from '../SelectButton'; import CurrentFoundationInfo from './CurrentFoundationInfo'; import useSWR from 'swr'; import * as chatAPI from '../../../lib/transformerlab-api-sdk'; -import Welcome from '../../Welcome'; + +import { modelTypes, licenseTypes, filterByFilters } from '../../../lib/utils'; type Order = 'asc' | 'desc'; @@ -74,6 +68,8 @@ export default function SelectAModel({ }) { const [order, setOrder] = useState('desc'); const [open, setOpen] = useState(false); + const [searchText, setSearchText] = useState(''); + const [filters, setFilters] = useState({}); const { data, error, isLoading, mutate } = useSWR( chatAPI.Endpoints.Models.LocalList(), @@ -99,17 +95,30 @@ export default function SelectAModel({ - Category - { + setFilters({ ...filters, architecture: newValue }); + }} + > + {modelTypes.map((type) => ( + + ))} @@ -194,7 +203,12 @@ export default function SelectAModel({ >   - } /> + setSearchText(e.target.value)} + startDecorator={} + /> {renderFilters()} @@ -252,7 +266,7 @@ export default function SelectAModel({ {data && - data.map((row) => ( + filterByFilters(data, searchText, filters).map((row) => ( diff --git a/src/renderer/components/ModelZoo/LocalModels.tsx b/src/renderer/components/ModelZoo/LocalModels.tsx index da5e25b..a50c004 100644 --- a/src/renderer/components/ModelZoo/LocalModels.tsx +++ b/src/renderer/components/ModelZoo/LocalModels.tsx @@ -43,44 +43,9 @@ import useSWR from 'swr'; import * as chatAPI from '../../lib/transformerlab-api-sdk'; import Welcome from '../Welcome'; -type Order = 'asc' | 'desc'; - -const modelTypes = [ - 'All', - 'MLX', - 'GGUF', - 'LlamaForCausalLM', - 'MistralForCausalLM', - 'T5ForConditionalGeneration', - 'PhiForCausalLM', - 'GPTBigCodeForCausalLM', -]; - -const licenseTypes = [ - 'All', - 'MIT', - 'CC BY-SA-4.0', - 'Apache 2.0', - 'Meta Custom', - 'GPL', -]; +import { modelTypes, licenseTypes, filterByFilters } from '../../lib/utils'; -function filterByFilters(data, searchText = '', filters = {}) { - return data.filter((row) => { - if (row.name.toLowerCase().includes(searchText.toLowerCase())) { - for (const filterKey in filters) { - console.log(filterKey, filters[filterKey]); - if (filters[filterKey] !== 'All') { - if (row[filterKey] !== filters[filterKey]) { - return false; - } - } - } - return true; - } - return false; - }); -} +type Order = 'asc' | 'desc'; const fetcher = (url) => fetch(url).then((res) => res.json()); diff --git a/src/renderer/components/ModelZoo/ModelStore.tsx b/src/renderer/components/ModelZoo/ModelStore.tsx index 4428bb9..8649db1 100644 --- a/src/renderer/components/ModelZoo/ModelStore.tsx +++ b/src/renderer/components/ModelZoo/ModelStore.tsx @@ -31,6 +31,8 @@ import useSWR from 'swr'; import * as chatAPI from '../../lib/transformerlab-api-sdk'; import TinyMLXLogo from '../Shared/TinyMLXLogo'; +import { modelTypes, licenseTypes, filterByFilters } from '../../lib/utils'; + function descendingComparator(a: T, b: T, orderBy: keyof T) { if (b[orderBy] < a[orderBy]) { return -1; @@ -74,43 +76,6 @@ function stableSort( return stabilizedThis.map((el) => el[0]); } -function filterByFilters(data, searchText = '', filters = {}) { - return data.filter((row) => { - if (row.name.toLowerCase().includes(searchText.toLowerCase())) { - for (const filterKey in filters) { - console.log(filterKey, filters[filterKey]); - if (filters[filterKey] !== 'All') { - if (row[filterKey] !== filters[filterKey]) { - return false; - } - } - } - return true; - } - return false; - }); -} - -const modelTypes = [ - 'All', - 'MLX', - 'GGUF', - 'LlamaForCausalLM', - 'MistralForCausalLM', - 'T5ForConditionalGeneration', - 'PhiForCausalLM', - 'GPTBigCodeForCausalLM', -]; - -const licenseTypes = [ - 'All', - 'MIT', - 'CC BY-SA-4.0', - 'Apache 2.0', - 'Meta Custom', - 'GPL', -]; - const fetcher = (url) => fetch(url).then((res) => res.json()); export default function ModelStore() { diff --git a/src/renderer/lib/utils.ts b/src/renderer/lib/utils.ts index ca9e0e6..01a382c 100644 --- a/src/renderer/lib/utils.ts +++ b/src/renderer/lib/utils.ts @@ -27,3 +27,40 @@ export function formatBytes(bytes: number, decimals = 2): string { return `${parseFloat((bytes / k ** i).toFixed(dm))} ${sizes[i]}`; } + +export const modelTypes = [ + 'All', + 'MLX', + 'GGUF', + 'LlamaForCausalLM', + 'MistralForCausalLM', + 'T5ForConditionalGeneration', + 'PhiForCausalLM', + 'GPTBigCodeForCausalLM', +]; + +export const licenseTypes = [ + 'All', + 'MIT', + 'CC BY-SA-4.0', + 'Apache 2.0', + 'Meta Custom', + 'GPL', +]; + +export function filterByFilters(data, searchText = '', filters = {}) { + return data.filter((row) => { + if (row.name.toLowerCase().includes(searchText.toLowerCase())) { + for (const filterKey in filters) { + console.log(filterKey, filters[filterKey]); + if (filters[filterKey] !== 'All') { + if (row[filterKey] !== filters[filterKey]) { + return false; + } + } + } + return true; + } + return false; + }); +}