Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: activate model installation in desktop onboarding #261

Merged
merged 9 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ export const ModelQuailityTag = ({
quality: OllamaModelQuality;
}) => {
const colorMap: { [key in OllamaModelQuality]: string } = {
[OllamaModelQuality.Bad]: 'border-red-700 bg-red-400 text-red-600',
[OllamaModelQuality.Medium]:
'border-yellow-700 bg-yellow-400 text-yellow-600',
[OllamaModelQuality.Great]: 'border-green-700 bg-green-400 text-green-600',
[OllamaModelQuality.Bad]: 'bg-red-900 text-red-400',
[OllamaModelQuality.Medium]: 'text-yellow-400 bg-yellow-900',
[OllamaModelQuality.Great]: 'text-green-400 bg-green-900',
};
return (
<Badge className={cn('capitalize', colorMap[quality])} variant="outline">
<Badge
className={cn(
'rounded-full border-0 px-2 py-1 font-normal capitalize',
colorMap[quality],
)}
variant="outline"
>
{quality}
</Badge>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { Badge } from '@shinkai_network/shinkai-ui';
import { cn } from '@shinkai_network/shinkai-ui/utils';

import { OllamaModelSpeed } from '../../../lib/shinkai-node-manager/ollama-models';
Expand All @@ -12,8 +11,8 @@ export const ModelSpeedTag = ({ speed }: { speed: OllamaModelSpeed }) => {
[OllamaModelSpeed.VeryFast]: '🐆',
};
return (
<Badge className={cn('capitalize')} variant="outline">
<div className={cn(' px-2 font-normal capitalize')}>
{speed} {emojiMap[speed]}
</Badge>
</div>
);
};
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { useSyncOllamaModels } from '@shinkai_network/shinkai-node-state/lib/mutations/syncOllamaModels/useSyncOllamaModels';
import {
Badge,
Button,
Progress,
ScrollArea,
Expand All @@ -11,10 +10,12 @@ import {
TableHeader,
TableRow,
} from '@shinkai_network/shinkai-ui';
import { useMap } from '@shinkai_network/shinkai-ui/hooks';
import { cn } from '@shinkai_network/shinkai-ui/utils';
import { Loader2 } from 'lucide-react';
import { motion } from 'framer-motion';
import { Download, Loader2, Minus } from 'lucide-react';
import { ModelResponse, ProgressResponse } from 'ollama/browser';
import { useEffect, useRef, useState } from 'react';
import { useEffect } from 'react';
import { toast } from 'sonner';

import {
Expand All @@ -36,25 +37,29 @@ export const OllamaModels = () => {
const auth = useAuth((auth) => auth.auth);
const { data: ollamaApiUrl } = useShinkaiNodeGetOllamaApiUrlQuery();
const ollamaConfig = { host: ollamaApiUrl || 'http://127.0.0.1:11435' };

const installedOllamaModelsMap = useMap<string, ModelResponse>();
const pullingModelsMap = useMap<string, ProgressResponse>();

const { data: isShinkaiNodeRunning } = useShinkaiNodeIsRunningQuery();
const { mutateAsync: shinkaiNodeSpawn } = useShinkaiNodeSpawnMutation({});
const { mutateAsync: syncOllamaModels } = useSyncOllamaModels(OLLAMA_MODELS.map((value) => value.fullName));
const { mutateAsync: syncOllamaModels } = useSyncOllamaModels(
OLLAMA_MODELS.map((value) => value.fullName),
);
const { isLoading: isOllamaListLoading, data: installedOllamaModels } =
useOllamaListQuery(ollamaConfig, {});
const { mutateAsync: ollamaPull } = useOllamaPullMutation(ollamaConfig, {
onSuccess: (data, input) => {
handlePullProgress(input.model, data);
},
onError: (_, input) => {
pullingModelsMap.current = {
...pullingModelsMap.current,
[input.model]: undefined,
};
pullingModelsMap.delete(input.model);
},
});
const { mutateAsync: ollamaRemove } = useOllamaRemoveMutation(ollamaConfig, {
onSuccess: (_, input) => {
toast.success(`Model ${input.model} removed`);
installedOllamaModelsMap.delete(input.model);
},
onError: (error, input) => {
toast.error(`Error removing ${input.model}. ${error.message}`);
Expand All @@ -70,10 +75,7 @@ export const OllamaModels = () => {
if (!progress) {
continue;
}
pullingModelsMap.current = {
...pullingModelsMap.current,
[model]: progress,
};
pullingModelsMap.set(model, progress);
if (progress.status === 'success') {
toast.success(`Model ${model} pull completed`);
if (auth) {
Expand All @@ -97,36 +99,20 @@ export const OllamaModels = () => {
} catch (error) {
toast.error(`Error pulling model ${model}. ${error?.toString()}`);
} finally {
pullingModelsMap.current = {
...pullingModelsMap.current,
[model]: undefined,
};
pullingModelsMap.delete(model);
}
};
const [installedOllamaModelsMap, setInstalledOllamaModelsMap] = useState(
new Map<string, ModelResponse>(),
);
// const [pullingModelsMap, setPullingModelsMap] = useState<{
// [model: string]: ProgressResponse | undefined;
// }>();
const pullingModelsMap = useRef<{
[model: string]: ProgressResponse | undefined;
}>();

const getProgress = (progress: ProgressResponse): number => {
return Math.ceil((100 * progress.completed) / progress.total);
return Math.ceil((100 * (progress.completed ?? 0)) / (progress.total ?? 1));
};

useEffect(() => {
setInstalledOllamaModelsMap(
new Map(
installedOllamaModels?.models.map((modelResponse) => [
modelResponse.name,
modelResponse,
]) || [],
),
);
}, [installedOllamaModels, setInstalledOllamaModelsMap]);
installedOllamaModels?.models &&
installedOllamaModels.models.forEach((modelResponse) => {
installedOllamaModelsMap.set(modelResponse.name, modelResponse);
});
}, [installedOllamaModels]);

if (!isShinkaiNodeRunning) {
return (
Expand All @@ -149,33 +135,34 @@ export const OllamaModels = () => {
);
}
return (
<ScrollArea className="h-full rounded-md border">
<Table>
<TableHeader className="sticky top-0 bg-gray-700">
<ScrollArea className="h-full flex-1 rounded-md">
<Table className="w-full border-collapse text-[13px]">
<TableHeader className="bg-gray-400 text-xs">
<TableRow>
<TableHead className="w-[300px]">AI Name</TableHead>
<TableHead className="md:w-[300px] lg:w-[480px]">AI Name</TableHead>
<TableHead>Data Limit</TableHead>
<TableHead>Quality</TableHead>
<TableHead>Speed</TableHead>
<TableHead>Size</TableHead>
<TableHead />
<TableHead className="w-[80px]">Size</TableHead>
<TableHead className="w-[180px]" />
</TableRow>
</TableHeader>
<TableBody>
{OLLAMA_MODELS.map((model) => {
return (
<TableRow key={model.fullName}>
<TableRow
className="transition-colors hover:bg-gray-300/50"
key={model.fullName}
>
<TableCell>
<div className="flex flex-col space-y-2">
<div className="flex flex-row space-x-2">
<span className="font-medium">{model.name}</span>
</div>
<span className="text-gray-80 text-ellipsis text-xs">
<div className="flex flex-col items-start gap-2">
<span className="font-medium">{model.name}</span>
{/*<Badge className={cn('text-[8px]')} variant="outline">*/}
{/* {model.fullName}*/}
{/*</Badge>*/}
<span className="text-gray-80 line-clamp-3 text-ellipsis text-xs">
{model.description}
</span>
<Badge className={cn('text-[8px]')} variant="outline">
{model.fullName}
</Badge>
</div>
</TableCell>
<TableCell>
Expand All @@ -189,40 +176,57 @@ export const OllamaModels = () => {
</TableCell>
<TableCell>{model.size} GB</TableCell>
<TableCell>
{isOllamaListLoading ? (
<Loader2 className="animate-spin" />
) : installedOllamaModelsMap.has(model.fullName) ? (
<Button
className="hover:border-brand py-1.5 text-sm hover:bg-transparent hover:text-white"
onClick={() => {
ollamaRemove({ model: model.fullName });
}}
variant={'destructive'}
>
Delete
</Button>
) : pullingModelsMap.current?.[model.fullName] ? (
<div className="flex flex-col space-y-1">
<Progress
className="h-4 w-[150px] bg-gray-700 [&>*]:bg-gray-100"
value={getProgress(
// eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain
pullingModelsMap.current?.[model.fullName]!,
)}
/>
<span>
{pullingModelsMap.current?.[model.fullName]?.status}
</span>
</div>
) : (
<Button
className="hover:border-brand py-1.5 text-sm hover:bg-transparent hover:text-white"
onClick={() => ollamaPull({ model: model.fullName })}
variant={'outline'}
>
Pull
</Button>
)}
<motion.div
className="flex items-center justify-center"
layout
>
{isOllamaListLoading ? (
<Loader2 className="animate-spin" />
) : installedOllamaModelsMap.has(model.fullName) ? (
<Button
className="hover:border-brand py-1.5 text-sm hover:text-white"
onClick={() => {
ollamaRemove({ model: model.fullName });
}}
size="auto"
variant={'destructive'}
>
<Minus className="mr-2 h-3 w-3" />
Remove
</Button>
) : pullingModelsMap.get(model.fullName) ? (
<div className="flex flex-col items-center gap-1">
<span className="text-xs text-gray-100">
{getProgress(
pullingModelsMap.get(
model.fullName,
) as ProgressResponse,
) + '%'}
</span>
<Progress
className="h-2 w-full bg-gray-200 [&>*]:bg-gray-100"
value={getProgress(
pullingModelsMap.get(
model.fullName,
) as ProgressResponse,
)}
/>
<span className="text-xs text-gray-100">
{pullingModelsMap.get(model.fullName)?.status}
</span>
</div>
) : (
<Button
className="hover:border-brand py-1.5 text-sm hover:bg-transparent hover:text-white"
onClick={() => ollamaPull({ model: model.fullName })}
size="auto"
variant={'outline'}
>
<Download className="mr-2 h-4 w-4" />
Install
</Button>
)}
</motion.div>
</TableCell>
</TableRow>
);
Expand Down
47 changes: 26 additions & 21 deletions apps/shinkai-desktop/src/pages/ai-model-installation.tsx
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
import { buttonVariants } from '@shinkai_network/shinkai-ui';
import { cn } from '@shinkai_network/shinkai-ui/utils';
import { QueryClientProvider } from '@tanstack/react-query';
import { ArrowRight } from 'lucide-react';
import { Link } from 'react-router-dom';

import { SubpageLayout } from './layout/simple-layout';
import { OllamaModels } from '../components/shinkai-node-manager/ollama-models';
import { queryClient } from '../lib/shinkai-node-manager/shinkai-node-manager-client';
import { FixedHeaderLayout } from './layout/simple-layout';

const AIModelInstallation = () => {
return (
<SubpageLayout title="Install AI">
<div className="flex h-full flex-col">
<img
alt="AI Model Installation"
src="https://via.placeholder.com/150"
/>
</div>
<Link
className={cn(
buttonVariants({
size: 'lg',
}),
'mt-4 w-full',
)}
to={{
pathname: '/',
}}
<QueryClientProvider client={queryClient}>
<FixedHeaderLayout
className="relative flex w-full max-w-6xl flex-col gap-2 px-4"
title="Install AI"
>
Continue
</Link>
</SubpageLayout>
<OllamaModels />
<div className="flex justify-center pt-3">
<Link
className={cn(
buttonVariants({
size: 'lg',
}),
'min-w-[200px] gap-2 px-6 py-2.5',
)}
to={{ pathname: '/' }}
>
Continue
<ArrowRight className="h-4 w-4" />
</Link>
</div>
</FixedHeaderLayout>
</QueryClientProvider>
);
};

Expand Down
21 changes: 4 additions & 17 deletions apps/shinkai-desktop/src/pages/get-started.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ const GetStartedPage = () => {
node_encryption_pk: response.data?.encryption_public_key ?? '',
};
setAuth(updatedSetupData);
navigate('/');
// Hide http subscription for now
// navigate('/connect-ai');
navigate('/ai-model-installation');
} else {
throw new Error('Failed to submit registration');
}
Expand Down Expand Up @@ -73,27 +75,12 @@ const GetStartedPage = () => {
}
return (
<OnboardingLayout>
<div className="flex h-full flex-col">
<div className="mx-auto flex h-full max-w-lg flex-col">
<p className="text-gray-80 text-center text-base tracking-wide">
Transform your desktop experience using AI with Shinkai Desktop{' '}
<span aria-hidden> 🔑</span>
</p>
<div className="mt-20 flex flex-1 flex-col gap-10">
{/* Note: Temporary disabled, model manager and http subscriptions are work in progress */}
{/*<Link*/}
{/* className={cn(*/}
{/* buttonVariants({*/}
{/* size: 'lg',*/}
{/* }),*/}
{/* 'w-full',*/}
{/* )}*/}
{/* state={{ connectionType: 'local' }}*/}
{/* to={{*/}
{/* pathname: '/onboarding',*/}
{/* }}*/}
{/*>*/}
{/* Shinkai Private (Local)*/}
{/*</Link>*/}
<Button
isLoading={shinkaiNodeSpawnIsPending}
onClick={() => shinkaiNodeSpawn()}
Expand Down
Loading
Loading