Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai: generate: remove endpoint from experiment & remove beta from path #2318

Merged
merged 4 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ afterEach(async () => {
await clearDatabase(server);
});

const testBothRoutes = (testFn) => {
describe("generate route", () => {
testFn("/generate");
});

describe("beta generate route", () => {
testFn("/beta/generate");
});
};
Comment on lines +74 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep in mind we need to update CloudFlare routes to also skip the regular worker on the new route.


describe("controllers/generate", () => {
let client: TestClient;
let adminUser: User;
Expand Down Expand Up @@ -145,9 +155,9 @@ describe("controllers/generate", () => {
return form;
};

describe("API proxies", () => {
it("should call the AI Gateway for generate API /audio-to-text", async () => {
const res = await client.fetch("/beta/generate/audio-to-text", {
testBothRoutes((basePath) => {
it(`should call the AI Gateway for ${basePath}/audio-to-text`, async () => {
const res = await client.fetch(`${basePath}/audio-to-text`, {
method: "POST",
body: buildMultipartBody(
{},
Expand All @@ -162,8 +172,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "audio-to-text": 1 });
});

it("should call the AI Gateway for generate API /text-to-image", async () => {
const res = await client.post("/beta/generate/text-to-image", {
it(`should call the AI Gateway for ${basePath}/text-to-image`, async () => {
const res = await client.post(`${basePath}/text-to-image`, {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand All @@ -174,8 +184,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "text-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-image", async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
it(`should call the AI Gateway for ${basePath}/image-to-image`, async () => {
const res = await client.fetch(`${basePath}/image-to-image`, {
method: "POST",
body: buildMultipartBody({
prompt: "replace the suit with a bathing suit",
Expand All @@ -189,8 +199,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-video", async () => {
const res = await client.fetch("/beta/generate/image-to-video", {
it(`should call the AI Gateway for ${basePath}/image-to-video`, async () => {
const res = await client.fetch(`${basePath}/image-to-video`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand All @@ -202,8 +212,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-video": 1 });
});

it("should call the AI Gateway for generate API /upscale", async () => {
const res = await client.fetch("/beta/generate/upscale", {
it(`should call the AI Gateway for ${basePath}/upscale`, async () => {
const res = await client.fetch(`${basePath}/upscale`, {
method: "POST",
body: buildMultipartBody({ prompt: "enhance" }),
});
Expand All @@ -215,8 +225,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ upscale: 1 });
});

it("should call the AI Gateway for generate API /segment-anything-2", async () => {
const res = await client.fetch("/beta/generate/segment-anything-2", {
it(`should call the AI Gateway for ${basePath}/segment-anything-2`, async () => {
const res = await client.fetch(`${basePath}/segment-anything-2`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand Down Expand Up @@ -260,7 +270,7 @@ describe("controllers/generate", () => {

for (const [title, input, error] of testCases) {
it(title, async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
const res = await client.fetch("/generate/image-to-image", {
gioelecerati marked this conversation as resolved.
Show resolved Hide resolved
method: "POST",
body: input,
});
Expand All @@ -287,7 +297,7 @@ describe("controllers/generate", () => {
}

it("should log all requests to db", async () => {
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand Down Expand Up @@ -325,7 +335,7 @@ describe("controllers/generate", () => {
`{"details":{"msg":"sudden error"}}`,
);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand All @@ -345,7 +355,7 @@ describe("controllers/generate", () => {
it("should log non JSON outputs as strings to db", async () => {
mockFetchHttpError(418, "text/plain", `I'm not Jason`);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(418);
Expand All @@ -364,7 +374,7 @@ describe("controllers/generate", () => {
mockedFetchWithTimeout.mockImplementation(() => {
throw new Error("on your face");
});
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand Down Expand Up @@ -394,10 +404,10 @@ describe("controllers/generate", () => {

const makeAiGenReq = (pipeline: (typeof pipelines)[number]) =>
pipeline === "text-to-image"
? client.post(`/beta/generate/${pipeline}`, {
? client.post(`/generate/${pipeline}`, {
prompt: "whatever you feel like",
})
: client.fetch(`/beta/generate/${pipeline}`, {
: client.fetch(`/generate/${pipeline}`, {
method: "POST",
body: buildMultipartBody(
pipeline === "image-to-video" ? {} : { prompt: "make magic" },
Expand Down
8 changes: 7 additions & 1 deletion packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ const aiGenerateDurationMetric = new promclient.Histogram({

const app = Router();

app.use(experimentSubjectsOnly("ai-generate"));
gioelecerati marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Remove beta paths middleware
app.use((req, res, next) => {
if (req.path.startsWith("/beta/generate")) {
req.url = req.url.replace("/beta/generate", "/generate");
}
next();
});
gioelecerati marked this conversation as resolved.
Show resolved Hide resolved

const rateLimiter: RequestHandler = async (req, res, next) => {
const now = Date.now();
Expand Down
2 changes: 2 additions & 0 deletions packages/api/src/controllers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ export default {
"api-token": apiToken,
asset,
auth,
generate,
// TODO: Remove beta paths
"beta/generate": generate,
broadcaster,
clip,
Expand Down
14 changes: 7 additions & 7 deletions packages/api/src/schema/ai-api-schema.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
openapi: 3.1.0
paths:
/api/beta/generate/text-to-image:
/api/generate/text-to-image:
post:
tags:
- generate
Expand Down Expand Up @@ -60,7 +60,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: textToImage
/api/beta/generate/image-to-image:
/api/generate/image-to-image:
post:
tags:
- generate
Expand Down Expand Up @@ -120,7 +120,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: imageToImage
/api/beta/generate/image-to-video:
/api/generate/image-to-video:
post:
tags:
- generate
Expand Down Expand Up @@ -180,7 +180,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: imageToVideo
/api/beta/generate/upscale:
/api/generate/upscale:
post:
tags:
- generate
Expand Down Expand Up @@ -240,7 +240,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: upscale
/api/beta/generate/audio-to-text:
/api/generate/audio-to-text:
post:
tags:
- generate
Expand Down Expand Up @@ -308,7 +308,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: audioToText
/api/beta/generate/segment-anything-2:
/api/generate/segment-anything-2:
post:
tags:
- generate
Expand Down Expand Up @@ -854,4 +854,4 @@ components:
securitySchemes:
HTTPBearer:
type: http
scheme: bearer
scheme: bearer
14 changes: 12 additions & 2 deletions packages/api/src/schema/pull-ai-schema.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,20 @@ const downloadAiSchema = async () => {
// add studio-api-error schema
schema.components.schemas["studio-api-error"] = studioApiErrorSchema;

// TODO: Remove beta paths
const newPaths = {};
Object.entries(schema.paths).forEach(([path, value]) => {
const generatePath = `/api/generate${path}`;
const betaGeneratePath = `/api/beta/generate${path}`;
newPaths[generatePath] = value;
newPaths[betaGeneratePath] = value;
});
schema.paths = newPaths;
gioelecerati marked this conversation as resolved.
Show resolved Hide resolved

// patches to the paths section
schema.paths = mapObject(schema.paths, (path, value) => {
// prefix paths with /api/beta/generate
path = `/api/beta/generate${path}`;
// prefix paths with /api/generate
path = `/api/generate${path}`;
// remove security field
delete value.post.security;
// add Studio API error as oneOf to all of the error responses
Expand Down
Loading