From 613b18d513415e736f6a089aad63647c50c0b36f Mon Sep 17 00:00:00 2001 From: Andrew Powers-Holmes Date: Mon, 15 Jul 2024 22:12:34 +1000 Subject: [PATCH] nvidia ruins lives --- docker-bake.hcl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/docker-bake.hcl b/docker-bake.hcl index 9f0d58a..833098b 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -95,6 +95,15 @@ function torchIndex { ) } +function cudnnTag { + params = [cudaVersion] + result = ( + and(split(".", cudaVersion)[0] >= 12, split(".", cudaVersion)[1] > 1) + ? "cudnn" + : "cudnn8" + ) +} + # set to "true" by github actions, used to disable auto-tag variable "CI" { default = "" } @@ -126,7 +135,7 @@ target "base" { context = "docker/base" target = equal(torch.xformers, "") ? "base" : "xformers-binary" contexts = { - base-cuda = "docker-image://${cudaImage(cuda.version, "devel", "cudnn8")}" + base-cuda = "docker-image://${cudaImage(cuda.version, "devel", cudnnTag(cuda.version))}" } matrix = { torch = [ @@ -136,7 +145,7 @@ target "base" { xformers = "xformers>=0.0.27" }, { - version = "2.3.0" + version = "2.3.1" index = "https://pypi.org/simple" xformers = "xformers>=0.0.27" }, @@ -194,10 +203,10 @@ target xformers-wheel { } target local-torchrelease { - inherits = ["base-cu121-torch230"] + inherits = ["base-cu124-torch230"] target = "xformers-binary" tags = [ - repoImage("base", cudaName("12.1.1"), torchName("2.2.0")), + repoImage("base", cudaName("12.4.1"), torchName("2.3.1")), repoImage("base", "latest"), ] args = {}