Skip to content

Commit

Permalink
fix: Update samplers to use type and rename 'k_euler_ancestral' to 'k…
Browse files Browse the repository at this point in the history
…_euler_a'. (#50)
  • Loading branch information
daveschumaker authored Sep 1, 2024
1 parent 3ba3f17 commit d5cda1b
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 54 deletions.
18 changes: 2 additions & 16 deletions app/_components/AdvancedOptions/SamplerSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,7 @@
import { useInput } from '@/app/_providers/PromptInputProvider'
import Select, { SelectOption } from '../Select'
import OptionLabel from './OptionLabel'

export type SamplerOption =
| 'DDIM'
| 'k_dpm_2_a'
| 'k_dpm_2'
| 'k_dpm_adaptive'
| 'k_dpm_fast'
| 'k_dpmpp_2m'
| 'k_dpmpp_2s_a'
| 'k_dpmpp_sde'
| 'k_euler_a'
| 'k_euler'
| 'k_heun'
| 'k_lms'
| 'lcm'
import { SamplerOption } from '@/app/_types/HordeTypes'

const samplers: Array<{ value: SamplerOption; label: SamplerOption }> = [
{ value: 'DDIM', label: 'DDIM' },
Expand Down Expand Up @@ -47,7 +33,7 @@ export default function SamplerSelect() {
<div className="w-full">
<Select
onChange={(option: SelectOption) => {
setInput({ sampler: option.value as string })
setInput({ sampler: option.value as SamplerOption })
}}
options={samplers.map((sampler) => ({
value: sampler.value,
Expand Down
70 changes: 35 additions & 35 deletions app/_data-models/ImageParamsForHordeApi.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,40 @@ describe('ImageParamsForHordeApi', () => {
height: 512,
width: 512,
steps: 30,
sampler: 'k_euler_ancestral',
sampler: 'k_euler_a',
seed: '123456',
models: ['stable_diffusion']
})

// Mock AppSettings.get
;(AppSettings.get as jest.Mock).mockImplementation(
(key: AppSettingsKey) => {
const settings: AppSettingsParams = {
allowedWorkers: [],
allowNsfwImages: false,
apiKey: 'test-api-key',
autoDowngrade: false,
blockedWorkers: [],
civitAiBaseModelFilter: [],
negativePanelOpen: false,
runInBackground: false,
saveInputOnCreate: false,
sharedKey: '',
slow_workers: true,
useAllowedWorkers: false,
useBeta: false,
useBlockedWorkers: false,
useReplacementFilter: true,
useTrusted: true
// Mock AppSettings.get
; (AppSettings.get as jest.Mock).mockImplementation(
(key: AppSettingsKey) => {
const settings: AppSettingsParams = {
allowedWorkers: [],
allowNsfwImages: false,
apiKey: 'test-api-key',
autoDowngrade: false,
blockedWorkers: [],
civitAiBaseModelFilter: [],
negativePanelOpen: false,
runInBackground: false,
saveInputOnCreate: false,
sharedKey: '',
slow_workers: true,
useAllowedWorkers: false,
useBeta: false,
useBlockedWorkers: false,
useReplacementFilter: true,
useTrusted: true
}

if (key in settings) {
return settings[key]
}

throw new Error(`Unexpected key: ${key}`)
}

if (key in settings) {
return settings[key]
}

throw new Error(`Unexpected key: ${key}`)
}
)
)
})

test('setBaseParams should set correct base parameters', () => {
Expand All @@ -77,7 +77,7 @@ describe('ImageParamsForHordeApi', () => {
height: 512,
width: 512,
steps: 30,
sampler_name: 'k_euler_ancestral',
sampler_name: 'k_euler_a',
seed: '123456'
},
nsfw: false,
Expand Down Expand Up @@ -141,12 +141,12 @@ describe('ImageParamsForHordeApi', () => {

test('setSourceProcessing should set correct source processing parameters', async () => {
// Mock the getImagesForArtbotJobFromDexie function
;(dbModule.getImagesForArtbotJobFromDexie as jest.Mock).mockResolvedValue([
; (dbModule.getImagesForArtbotJobFromDexie as jest.Mock).mockResolvedValue([
{ imageBlobBuffer: new ArrayBuffer(8) }
])

// Mock blobToBase64
;(imageUtils.blobToBase64 as jest.Mock).mockResolvedValue('base64string')
// Mock blobToBase64
; (imageUtils.blobToBase64 as jest.Mock).mockResolvedValue('base64string')

promptInput.source_processing = SourceProcessing.Img2Img
promptInput.denoising_strength = 0.6
Expand Down Expand Up @@ -178,7 +178,7 @@ describe('ImageParamsForHordeApi', () => {
height: 512,
width: 512,
steps: 30,
sampler_name: 'k_euler_ancestral',
sampler_name: 'k_euler_a',
seed: '123456',
n: 1,
post_processing: [],
Expand All @@ -205,7 +205,7 @@ describe('ImageParamsForHordeApi', () => {
expect(result.height).toBe(512)
expect(result.width).toBe(512)
expect(result.steps).toBe(30)
expect(result.sampler).toBe('k_euler_ancestral')
expect(result.sampler).toBe('k_euler_a')
expect(result.seed).toBe('123456')
expect(result.models).toEqual(['stable_diffusion'])
})
Expand Down
3 changes: 2 additions & 1 deletion app/_data-models/ImageParamsForHordeApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
ControlTypes,
HordeTi,
Lora,
SamplerOption,
SourceProcessing
} from '../_types/HordeTypes'
import { castTiInject } from '../_utils/hordeUtils'
Expand All @@ -26,7 +27,7 @@ interface HordeApiParamsBuilderInterface {
}

export interface ImageParams {
sampler_name?: string // Optional due to ControlNet
sampler_name?: SamplerOption // Optional due to ControlNet
cfg_scale: number
height: number
width: number
Expand Down
2 changes: 1 addition & 1 deletion app/_data-models/PromptInput.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ describe('PromptInput', () => {
expect(defaultPromptInput.preset).toEqual([])
expect(defaultPromptInput.prompt).toBe('')
expect(defaultPromptInput.return_control_map).toBe(false)
expect(defaultPromptInput.sampler).toBe('euler_a')
expect(defaultPromptInput.sampler).toBe('k_euler_a')
expect(defaultPromptInput.seed).toBe('')
expect(defaultPromptInput.source_processing).toBe(SourceProcessing.Prompt)
expect(defaultPromptInput.steps).toBe(8)
Expand Down
3 changes: 2 additions & 1 deletion app/_data-models/PromptInput.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { ImageOrientations, JobType, Workflow } from '@/app/_types/ArtbotTypes'
import {
ControlTypes,
SamplerOption,
SourceProcessing,
StylePresetConfig
} from '@/app/_types/HordeTypes'
Expand Down Expand Up @@ -64,7 +65,7 @@ class PromptInput {
post_processing: Array<string> = []
prompt: string = ''
return_control_map: boolean = false
sampler: string = 'euler_a'
sampler: SamplerOption = 'k_euler_a'
seed: string = ''
source_processing?: SourceProcessing = SourceProcessing.Prompt
steps: number = 8
Expand Down
15 changes: 15 additions & 0 deletions app/_types/HordeTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ export interface LoraConfig {
is_version: boolean
}

export type SamplerOption =
| 'DDIM'
| 'k_dpm_2_a'
| 'k_dpm_2'
| 'k_dpm_adaptive'
| 'k_dpm_fast'
| 'k_dpmpp_2m'
| 'k_dpmpp_2s_a'
| 'k_dpmpp_sde'
| 'k_euler_a'
| 'k_euler'
| 'k_heun'
| 'k_lms'
| 'lcm'

export interface SharedApiKey {
id: string
kudos: number
Expand Down

0 comments on commit d5cda1b

Please sign in to comment.