Host the static files (#3)
* Migrate to AI server

* Fix test cases

* Remove categories from downloadable assets

This should fix the tests and resolve the issue of internet access being required for compilation.

* Reimplement cache management as a process-based workflow

This way we avoid race conditions caused by attempting to download the same file twice when a new request arrives before the previous download has finished. (A brief usage sketch of the new cache API follows this list.)

* Set MIX_ENV to test on CI

* Remove unnecessary log

* Remove checkout LFS step from CI
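
For context, a minimal usage sketch of the reworked cache API introduced in lib/ex_vision/cache.ex below. This is an illustration, not part of the commit; it assumes the cache process is registered under the name ExVision.Cache by the new application supervisor (see lib/ex_vision/ex_vision.ex in the diff):

# lazy_get/3 now takes the cache process as its first argument and blocks
# (timeout :infinity) until the download has finished; concurrent callers
# requesting the same path are all replied to from a single download task.
{:ok, path} = ExVision.Cache.lazy_get(ExVision.Cache, "mobilenetv3small-classifier.onnx")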
daniel-jodlos authored May 27, 2024
1 parent 8d258f0 commit 58e940a
Showing 27 changed files with 175 additions and 161 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/elixir.yml
@@ -9,6 +9,9 @@ on:
permissions:
contents: read

env:
MIX_ENV: test

jobs:
build:
name: Build and test
@@ -29,8 +32,6 @@ jobs:
path: deps
key: ${{ runner.os }}-mix-${{ hashFiles('**/mix.lock') }}
restore-keys: ${{ runner.os }}-mix-
- name: Checkout LFS
uses: nschloe/[email protected]
- name: Install dependencies
run: mix deps.get && mix deps.compile
- name: Checks if compiles without warning
21 changes: 0 additions & 21 deletions .github/workflows/release.yml

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
@@ -177,3 +177,4 @@ $RECYCLE.BIN/
*.lnk

# End of https://www.gitignore.io/api/c,vim,linux,macos,elixir,windows,visualstudiocode
models/
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 1 addition & 2 deletions config/config.exs
@@ -4,7 +4,6 @@ config :nx, default_backend: EXLA.Backend
config :logger, level: :debug

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
server_url: URI.new!("https://ai.swmansion.com/exvision/files")

import_config "#{config_env()}.exs"
4 changes: 1 addition & 3 deletions config/dev.exs
@@ -2,6 +2,4 @@ import Config

config :ortex, Ortex.Native, features: ["coreml"]

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "models")
config :ex_vision, cache_path: "models"
4 changes: 0 additions & 4 deletions config/prod.exs
@@ -1,7 +1,3 @@
import Config

config :logger, level: :info

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
5 changes: 4 additions & 1 deletion config/runtime.exs
@@ -1,5 +1,8 @@
import Config

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
server_url:
"EX_VISION_HOSTING_URI"
|> System.get_env("https://ai.swmansion.com/exvision/files")
|> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
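
As a hedged illustration of the runtime configuration above (the lookup itself happens in config/runtime.exs, and the cache directory value below is only an example), both settings can still be overridden through environment variables before the application boots, e.g. in a test helper or a standalone .exs script:

# Point ExVision at a locally hosted file server and a custom cache directory.
# The variable names come from the diff above; the values are examples only.
System.put_env("EX_VISION_HOSTING_URI", "http://localhost:8000")
System.put_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/test_cache")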
1 change: 0 additions & 1 deletion examples/3-membrane.livemd
@@ -93,7 +93,6 @@ defmodule Membrane.ExVision.Detector do
|> then(&"#{&1}")
|> :base64.encode()
|> String.to_atom()
|> dbg()

{:ok, pid} = Model.start_link(name: name)

154 changes: 117 additions & 37 deletions lib/ex_vision/cache.ex
@@ -1,51 +1,130 @@
defmodule ExVision.Cache do
@moduledoc false

# Module responsible for handling model file caching

use GenServer
require Logger

@default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache")
defp get_cache_path() do
Application.get_env(:ex_vision, :cache_path, @default_cache_path)
end

@default_server_url Application.compile_env(:ex_vision, :server_url, "http://localhost:8000")
defp get_server_url() do
Application.get_env(:ex_vision, :server_url, @default_server_url)
end

@type lazy_get_option_t() ::
{:cache_path, Path.t()} | {:server_url, String.t() | URI.t()} | {:force, true}
@type lazy_get_option_t() :: {:force, boolean()}

@doc """
Lazily evaluate the path from the cache directory.
It will only download the file if it's missing or the `force: true` option is given.
"""
@spec lazy_get(Path.t(), options :: [lazy_get_option_t()]) ::
@spec lazy_get(term() | pid(), Path.t(), options :: [lazy_get_option_t()]) ::
{:ok, Path.t()} | {:error, reason :: atom()}
def lazy_get(path, options \\ []) do
options =
Keyword.validate!(options,
cache_path: get_cache_path(),
server_url: get_server_url(),
force: false
)

cache_path = Path.join(options[:cache_path], path)
ok? = File.exists?(cache_path)

if ok? and not options[:force] do
Logger.debug("Found existing cache entry for #{path}. Loading.")
{:ok, cache_path}
else
with {:ok, server_url} <- URI.new(options[:server_url]) do
download_url = URI.append_path(server_url, ensure_backslash(path))
download_file(download_url, cache_path)
end
def lazy_get(server, path, options \\ []) do
with {:ok, options} <- Keyword.validate(options, force: false),
do: GenServer.call(server, {:download, path, options}, :infinity)
end

@spec start_link(keyword()) :: GenServer.on_start()
def start_link(opts) do
{init_args, opts} = Keyword.split(opts, [:server_url, :cache_path])
GenServer.start_link(__MODULE__, init_args, opts)
end

@impl true
def init(opts) do
opts = Keyword.validate!(opts, cache_path: get_cache_path(), server_url: get_server_url())

with {:ok, server_url} <- URI.new(opts[:server_url]),
:ok <- File.mkdir_p(opts[:cache_path]) do
{:ok,
%{
downloads: %{},
server_url: server_url,
cache_path: opts[:cache_path],
refs: %{}
}}
end
end

@impl true
def handle_call({:download, cache_path, options}, from, state) do
file_path = Path.join(state.cache_path, cache_path)

updated_downloads =
Map.update(state.downloads, cache_path, MapSet.new([from]), &MapSet.put(&1, from))

cond do
Map.has_key?(state.downloads, cache_path) ->
{:noreply, %{state | downloads: updated_downloads}}

File.exists?(file_path) or options[:force] ->
{:reply, {:ok, file_path}, state}

true ->
ref = do_create_download_job(cache_path, state)

{:noreply,
%{state | downloads: updated_downloads, refs: Map.put(state.refs, ref, cache_path)}}
end
end

@impl true
def handle_info({ref, result}, state) do
state = emit(result, ref, state)
{:noreply, state}
end

@impl true
def handle_info({:DOWN, ref, :process, _pid, reason}, state) do
state =
if reason != :normal do
Logger.error("Task #{inspect(ref)} has crashed due to #{inspect(reason)}")
emit({:error, reason}, ref, state)
else
state
end

{:noreply, state}
end

@impl true
def handle_info(msg, state) do
Logger.warning("Received an unknown message #{inspect(msg)}. Ignoring")
{:noreply, state}
end

defp emit(message, ref, state) do
path = state.refs[ref]

state.downloads
|> Map.get(path, [])
|> Enum.each(fn from ->
GenServer.reply(from, message)
end)

%{state | refs: Map.delete(state.refs, ref), downloads: Map.delete(state.downloads, path)}
end

defp do_create_download_job(path, %{server_url: server_url, cache_path: cache_path}) do
target_file_path = Path.join(cache_path, path)
download_url = URI.append_path(server_url, ensure_backslash(path))

%Task{ref: ref} =
Task.async(fn ->
download_file(download_url, target_file_path)
end)

ref
end

@default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache")
defp get_cache_path() do
Application.get_env(:ex_vision, :cache_path, @default_cache_path)
end

@default_server_url Application.compile_env(
:ex_vision,
:server_url,
URI.new!("https://ai.swmansion.com/exvision/files")
)
defp get_server_url() do
Application.get_env(:ex_vision, :server_url, @default_server_url)
end

@spec download_file(URI.t(), Path.t()) ::
{:ok, Path.t()} | {:error, reason :: any()}
defp download_file(url, cache_path) do
@@ -59,6 +138,9 @@ defmodule ExVision.Cache do
end
end

defp ensure_backslash("/" <> _rest = i), do: i
defp ensure_backslash(i), do: "/" <> i

defp validate_download(path) do
if File.exists?(path),
do: :ok,
@@ -73,7 +155,8 @@ defmodule ExVision.Cache do
{:ok, _resp} ->
:ok

{:error, _reason} = error ->
{:error, reason} = error ->
Logger.error("Failed to download the file due to #{inspect(reason)}")
File.rm(target_file_path)
error
end
@@ -100,7 +183,4 @@ defmodule ExVision.Cache do
{:error, :connection_failed}
end
end

defp ensure_backslash("/" <> _rest = path), do: path
defp ensure_backslash(path), do: "/" <> path
end
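
To complement the diff above, a short sketch of starting the cache under a custom name with explicit options. This is an illustration only: the process name and file name below are hypothetical, while the option keys come from start_link/1 and init/1 above.

# :server_url and :cache_path are consumed by init/1; the remaining options
# (here :name) are forwarded to GenServer.start_link/3.
children = [
  {ExVision.Cache,
   name: MyApp.ModelCache,
   server_url: "http://localhost:8000",
   cache_path: "/tmp/ex_vision/cache"}
]

{:ok, _sup} = Supervisor.start_link(children, strategy: :one_for_one)

# Two overlapping requests for the same file share one download task and
# each receives the same {:ok, path} reply.
task = Task.async(fn -> ExVision.Cache.lazy_get(MyApp.ModelCache, "model.onnx") end)
{:ok, path} = ExVision.Cache.lazy_get(MyApp.ModelCache, "model.onnx")
{:ok, ^path} = Task.await(task, :infinity)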
2 changes: 1 addition & 1 deletion lib/ex_vision/classification/mobilenet_v3_small.ex
@@ -6,7 +6,7 @@ defmodule ExVision.Classification.MobileNetV3Small do
"""
use ExVision.Model.Definition.Ortex,
model: "mobilenetv3small-classifier.onnx",
categories: "imagenet_v2_categories.json"
categories: "assets/categories/imagenet_v2_categories.json"

require Bunch.Typespec
alias ExVision.Utils
2 changes: 1 addition & 1 deletion lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex
@@ -4,7 +4,7 @@ defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do
"""
use ExVision.Model.Definition.Ortex,
model: "fasterrcnn_resnet50_fpn_detector.onnx",
categories: "coco_categories.json"
categories: "assets/categories/coco_categories.json"

use ExVision.Detection.GenericDetector

2 changes: 1 addition & 1 deletion lib/ex_vision/detection/ssdlite320_mobilenetv3.ex
@@ -4,7 +4,7 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do
"""
use ExVision.Model.Definition.Ortex,
model: "ssdlite320_mobilenetv3_detector.onnx",
categories: "coco_categories.json"
categories: "assets/categories/coco_categories.json"

use ExVision.Detection.GenericDetector

10 changes: 10 additions & 0 deletions lib/ex_vision/ex_vision.ex
@@ -0,0 +1,10 @@
defmodule ExVision do
@moduledoc false
use Application

@impl true
def start(_type, _args) do
children = [{ExVision.Cache, name: ExVision.Cache}]
Supervisor.start_link(children, strategy: :one_for_one)
end
end
10 changes: 4 additions & 6 deletions lib/ex_vision/model/definition/ortex.ex
@@ -36,7 +36,7 @@ defmodule ExVision.Model.Definition.Ortex do
- `:cache_path` - specifies a caching directory for this model.
- `:providers` - a list of desired providers, sorted by preference. Onnx will attempt to use the first available provider. If none of the provided is available, onnx will fallback to `:cpu`. Default: `[:cpu]`
- `:batch_size` - specifies a default batch size for this instance. Default: `1`
- `:batch_size` - specifies a default batch size for this instance. Default: `1`.
"""
@type load_option_t() ::
{:cache_path, Path.t()}
@@ -85,13 +85,11 @@ defmodule ExVision.Model.Definition.Ortex do
{:ok, ExVision.Model.t()} | {:error, atom()}
def load_ortex_model(module, model_path, options) do
with {:ok, options} <-
Keyword.validate(options, [
:cache_path,
Keyword.validate(options,
batch_size: 1,
providers: [:cpu]
]),
cache_options = Keyword.take(options, [:cache_path, :file_path]),
{:ok, path} <- ExVision.Cache.lazy_get(model_path, cache_options),
),
{:ok, path} <- ExVision.Cache.lazy_get(ExVision.Cache, model_path),
{:ok, model} <- do_load_model(path, options[:providers]) do
output_names = ExVision.Utils.onnx_output_names(model)

17 changes: 2 additions & 15 deletions lib/ex_vision/model/definition/parts/with_categories.ex
@@ -1,24 +1,11 @@
defmodule ExVision.Model.Definition.Parts.WithCategories do
@moduledoc false
require Logger
alias ExVision.{Cache, Utils}

defp get_categories(file) do
file
|> Cache.lazy_get()
|> case do
{:ok, file} ->
Utils.load_categories(file)

error ->
Logger.error("Failed to load categories from #{file} due to #{inspect(error)}")
raise "Failed to load categories from #{file}"
end
end
alias ExVision.Utils

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> get_categories()
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
2 changes: 1 addition & 1 deletion lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex
@@ -4,7 +4,7 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do
"""
use ExVision.Model.Definition.Ortex,
model: "deeplab_v3_mobilenetv3_segmentation.onnx",
categories: "coco_with_voc_labels_categories.json"
categories: "assets/categories/coco_with_voc_labels_categories.json"

@type output_t() :: %{category_t() => Nx.Tensor.t()}

2 changes: 2 additions & 0 deletions mix.exs
@@ -30,6 +30,8 @@ defmodule ExVision.Mixfile do

def application do
[
included_applications: [:ex_vision],
mod: {ExVision, []},
extra_applications: []
]
end
3 changes: 0 additions & 3 deletions models/deeplab_v3_mobilenetv3_segmentation.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions models/fasterrcnn_resnet50_fpn_detector.onnx

This file was deleted.
