diff --git a/app.py b/app.py index d69108e..efc6c30 100644 --- a/app.py +++ b/app.py @@ -91,8 +91,8 @@ def pad(image, height, width, padding_value): return (ops.convert_to_numpy(output) * 255.0).astype("uint8") -def rand_augment(image, num_ops, magnitude): - output = ka_layers.vision.RandAugment(num_ops, magnitude)(image) +def rand_augment(image, p, num_ops, magnitude): + output = ka_layers.vision.RandAugment(p, num_ops, magnitude)(image) return (ops.convert_to_numpy(output) * 255.0).astype("uint8") @@ -186,8 +186,8 @@ def resize(image, height, width): return (ops.convert_to_numpy(output) * 255.0).astype("uint8") -def trivial_augment_wide(image): - output = ka_layers.vision.TrivialAugmentWide()(image) +def trivial_augment_wide(image, p): + output = ka_layers.vision.TrivialAugmentWide(p)(image) return (ops.convert_to_numpy(output) * 255.0).astype("uint8") @@ -328,6 +328,9 @@ def trivial_augment_wide(image): image = gr.Image(astronaut, label="Image") with gr.Column(): args = [ + gr.Number( + 1.0, label="p", minimum=0.0, maximum=1.0, step=0.1 + ), gr.Number(2, label="num_ops"), gr.Number(9, label="magnitude"), ] @@ -595,7 +598,11 @@ def trivial_augment_wide(image): with gr.Column(scale=2): image = gr.Image(astronaut, label="Image") with gr.Column(): - args = [] + args = [ + gr.Number( + 1.0, label="p", minimum=0.0, maximum=1.0, step=0.1 + ), + ] button = gr.Button("Run") with gr.Column(scale=2): outputs = gr.Image(label="Output")