Skip to content

Commit

Permalink
server: update
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf committed Nov 20, 2024
1 parent 2210257 commit 951b3f8
Showing 1 changed file with 47 additions and 35 deletions.
82 changes: 47 additions & 35 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ const char* sample_method_str[] = {
"dpm++2s_a",
"dpm++2m",
"dpm++2mv2",
"ipndm",
"ipndm_v",
"lcm",
};

Expand All @@ -48,7 +50,9 @@ const char* schedule_str[] = {
"default",
"discrete",
"karras",
"exponential",
"ays",
"gits",
};

enum SDMode {
Expand Down Expand Up @@ -676,45 +680,45 @@ static void log_server_request(const httplib::Request& req, const httplib::Respo
printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str());
}

void parseJsonPrompt(std::string json_str, SDParams* params) {
void parseJsonPrompt(std::string json_str, SDParams& params) {
using namespace nlohmann;
json payload = json::parse(json_str);
// if no exception, the request is a json object
// now we try to get the new param values from the payload object
// const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
try {
std::string prompt = payload["prompt"];
params->prompt = prompt;
params.prompt = prompt;
} catch (...) {
}
try {
std::string negative_prompt = payload["negative_prompt"];
params->negative_prompt = negative_prompt;
params.negative_prompt = negative_prompt;
} catch (...) {
}
try {
int clip_skip = payload["clip_skip"];
params->clip_skip = clip_skip;
int clip_skip = payload["clip_skip"];
params.clip_skip = clip_skip;
} catch (...) {
}
try {
float cfg_scale = payload["cfg_scale"];
params->cfg_scale = cfg_scale;
float cfg_scale = payload["cfg_scale"];
params.cfg_scale = cfg_scale;
} catch (...) {
}
try {
float guidance = payload["guidance"];
params->guidance = guidance;
float guidance = payload["guidance"];
params.guidance = guidance;
} catch (...) {
}
try {
int width = payload["width"];
params->width = width;
int width = payload["width"];
params.width = width;
} catch (...) {
}
try {
int height = payload["height"];
params->height = height;
int height = payload["height"];
params.height = height;
} catch (...) {
}
try {
Expand All @@ -727,25 +731,25 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
}
}
if (sample_method_found >= 0) {
params->sample_method = (sample_method_t)sample_method_found;
params.sample_method = (sample_method_t)sample_method_found;
} else {
sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown sampling method: %s\n", sample_method.c_str());
}
} catch (...) {
}
try {
int sample_steps = payload["sample_steps"];
params->sample_steps = sample_steps;
int sample_steps = payload["sample_steps"];
params.sample_steps = sample_steps;
} catch (...) {
}
try {
int64_t seed = payload["seed"];
params->seed = seed;
params.seed = seed;
} catch (...) {
}
try {
int batch_count = payload["batch_count"];
params->batch_count = batch_count;
int batch_count = payload["batch_count"];
params.batch_count = batch_count;
} catch (...) {
}

Expand All @@ -759,53 +763,53 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
}
try {
float control_strength = payload["control_strength"];
// params->control_strength = control_strength;
// params.control_strength = control_strength;
// LOG_WARN("control_strength is not supported yet\n");
sd_log(sd_log_level_t::SD_LOG_WARN, "control_strength is not supported yet\n", params);
} catch (...) {
}
try {
float style_strength = payload["style_strength"];
// params->style_strength = style_strength;
// params.style_strength = style_strength;
// LOG_WARN("style_strength is not supported yet\n");
sd_log(sd_log_level_t::SD_LOG_WARN, "style_strength is not supported yet\n", params);
} catch (...) {
}
try {
bool normalize_input = payload["normalize_input"];
params->normalize_input = normalize_input;
bool normalize_input = payload["normalize_input"];
params.normalize_input = normalize_input;
} catch (...) {
}
try {
std::string input_id_images_path = payload["input_id_images_path"];
// TODO replace with b64 image maybe?
params->input_id_images_path = input_id_images_path;
params.input_id_images_path = input_id_images_path;
} catch (...) {
}
try {
std::string slg_scale = payload["slg_scale"];
params->slg_scale = stof(slg_scale);
params.slg_scale = stof(slg_scale);
} catch (...) {
}
// TODO: more slg settings (layers, start and end)
try {
std::vector<int> skip_layers = payload["skip_layers"].get<std::vector<int>>();
params->skip_layers.clear();
params.skip_layers.clear();
for (int i = 0; i < skip_layers.size(); i++) {
params->skip_layers.push_back(skip_layers[i]);
params.skip_layers.push_back(skip_layers[i]);
}
} catch (...) {
}
try {
// skip_layer_start
float skip_layer_start = payload["skip_layer_start"].get<float>();
params->skip_layer_start = skip_layer_start;
float skip_layer_start = payload["skip_layer_start"].get<float>();
params.skip_layer_start = skip_layer_start;
} catch (...) {
}
try {
// skip_layer_end
float skip_layer_end = payload["skip_layer_end"].get<float>();
params->skip_layer_end = skip_layer_end;
float skip_layer_end = payload["skip_layer_end"].get<float>();
params.skip_layer_end = skip_layer_end;
} catch (...) {
}
}
Expand Down Expand Up @@ -863,7 +867,7 @@ const float sd_latent_rgb_proj[4][3]{
{-0.2829, 0.1762, 0.2721},
{-0.2120, -0.2616, -0.7177}};

void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
void proj_latents(struct ggml_tensor* latents, enum SDVersion version, uint8_t* data) {
const int channel = 3;
int width = latents->ne[0];
int height = latents->ne[1];
Expand All @@ -876,7 +880,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version

if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
latent_rgb_proj = sd3_latent_rgb_proj;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
latent_rgb_proj = flux_latent_rgb_proj;
} else {
// unknown model
Expand All @@ -897,7 +901,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
// unknown latent space
return;
}
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
int data_head = 0;
for (int j = 0; j < height; j++) {
for (int i = 0; i < width; i++) {
Expand Down Expand Up @@ -925,6 +928,15 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
data[data_head++] = (uint8_t)(b * 255.);
}
}
}

void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
const int channel = 3;
int width = latents->ne[0];
int height = latents->ne[1];
int dim = latents->ne[2];
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
proj_latents(latents, version, data);
stbi_write_png("latent-preview.png", width, height, channel, data, 0);
free(data);
}
Expand Down Expand Up @@ -982,7 +994,7 @@ int main(int argc, const char* argv[]) {

try {
std::string json_str = req.body;
parseJsonPrompt(json_str, &params);
parseJsonPrompt(json_str, params);
} catch (json::parse_error& e) {
// assume the request is just a prompt
// LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());
Expand Down

0 comments on commit 951b3f8

Please sign in to comment.