Skip to content

Commit

Permalink
Merge pull request #52 from Penll/main
Browse files Browse the repository at this point in the history
fix fooocus new version, add stop api
  • Loading branch information
konieshadow authored Nov 14, 2023
2 parents 582abd5 + 223f6f6 commit c263ac4
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 11 deletions.
10 changes: 8 additions & 2 deletions fooocusapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import uvicorn
from fooocusapi.api_utils import generation_output, req_to_params
import fooocusapi.file_utils as file_utils
from fooocusapi.models import AllModelNamesResponse, AsyncJobResponse, GeneratedImageResult, ImgInpaintOrOutpaintRequest, ImgPromptRequest, ImgUpscaleOrVaryRequest, JobQueueInfo, Text2ImgRequest
from fooocusapi.models import AllModelNamesResponse, AsyncJobResponse,StopResponse , GeneratedImageResult, ImgInpaintOrOutpaintRequest, ImgPromptRequest, ImgUpscaleOrVaryRequest, JobQueueInfo, Text2ImgRequest
from fooocusapi.parameters import GenerationFinishReason, ImageGenerationResult
from fooocusapi.task_queue import TaskType
from fooocusapi.worker import process_generate, task_queue
from fooocusapi.worker import process_generate, task_queue, process_top
from concurrent.futures import ThreadPoolExecutor

app = FastAPI()
Expand Down Expand Up @@ -66,6 +66,8 @@ def call_worker(req: Text2ImgRequest, accept: str):

return results

def stop_worker():
process_top()

@app.get("/")
def home():
Expand Down Expand Up @@ -176,6 +178,10 @@ def all_styles():
from modules.sdxl_styles import legal_style_names
return legal_style_names

@app.get("/v1/generation/stop", response_model=StopResponse, description="Job stoping")
def stop():
stop_worker()
return StopResponse(msg="success")

app.mount("/files", StaticFiles(directory=file_utils.output_dir), name="files")

Expand Down
2 changes: 1 addition & 1 deletion fooocusapi/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fooocusapi.parameters import ImageGenerationParams, ImageGenerationResult, available_aspect_ratios, default_aspect_ratio, inpaint_model_version, default_sampler, default_scheduler, default_base_model_name, default_refiner_model_name
from fooocusapi.task_queue import QueueTask
import modules.flags as flags
import modules.path as path
import modules.config as path
from modules.sdxl_styles import legal_style_names


Expand Down
5 changes: 4 additions & 1 deletion fooocusapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,7 @@ class AllModelNamesResponse(BaseModel):

model_config = ConfigDict(
protected_namespaces=('protect_me_', 'also_protect_')
)
)

class StopResponse(BaseModel):
msg: str
6 changes: 5 additions & 1 deletion fooocusapi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
task_queue = TaskQueue(queue_size=3, hisotry_size=6)


def process_top():
import fcbh.model_management
fcbh.model_management.interrupt_current_processing()

@torch.no_grad()
@torch.inference_mode()
def process_generate(queue_task: QueueTask, params: ImageGenerationParams) -> List[ImageGenerationResult]:
Expand All @@ -29,7 +33,7 @@ def process_generate(queue_task: QueueTask, params: ImageGenerationParams) -> Li
import modules.flags as flags
import modules.core as core
import modules.inpaint_worker as inpaint_worker
import modules.path as path
import modules.config as path
import modules.advanced_parameters as advanced_parameters
import modules.constants as constants
import fooocus_extras.preprocessors as preprocessors
Expand Down
11 changes: 5 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def download_models():
]

from modules.model_loader import load_file_from_url
from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path, \
checkpoint_downloads, embeddings_path, embeddings_downloads, lora_downloads
from modules.config import path_checkpoints as modelfile_path, path_loras as lorafile_path,path_vae_approx as vae_approx_path,path_fooocus_expansion as fooocus_expansion_path, \
checkpoint_downloads, path_embeddings as embeddings_path, embeddings_downloads, lora_downloads

for file_name, url in checkpoint_downloads.items():
load_file_from_url(url=url, model_dir=modelfile_path, file_name=file_name)
Expand Down Expand Up @@ -263,7 +263,7 @@ def prepare_environments(args) -> bool:
# Add dependent repositories to import path
sys.path.append(script_path)
fooocus_path = os.path.join(script_path, dir_repos, fooocus_name)
sys.path.append(fooocus_path)
sys.path.insert(0, fooocus_path) # need add __init__.py in folder "modules"
backend_path = os.path.join(fooocus_path, 'backend', 'headless')
if backend_path not in sys.path:
sys.path.append(backend_path)
Expand All @@ -280,8 +280,7 @@ def prepare_environments(args) -> bool:

sys.argv.append('--preset')
sys.argv.append(args.preset)

import modules.path as path
import modules.config as path
import fooocusapi.parameters as parameters
parameters.defualt_styles = path.default_styles
parameters.default_base_model_name = path.default_base_model_name
Expand Down Expand Up @@ -328,7 +327,7 @@ class Args(object):
prepare_environments(args)

if load_all_models:
import modules.path as path
import modules.config as path
from fooocusapi.parameters import inpaint_model_version
path.downloading_upscale_model()
path.downloading_inpaint_models(inpaint_model_version)
Expand Down

0 comments on commit c263ac4

Please sign in to comment.