diff --git a/docker-bake.hcl b/docker-bake.hcl index c31d594..97c5883 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -72,14 +72,17 @@ function repoImage { ]) } -# cursed override for cuda 12.1 on torch 2.0.1... +# cursed override for torch2.0.1 & CUDA 12, and torch2.1.0 & CUDA 11.8 function torchIndex { params = [base, version, cuda] result = ( equal(base, "") ? "https://pypi.org/simple" : ( - and(equal(version, "2.0.1"), or(equal(cuda, "12.1.1"), equal(cuda, "12.0.1"))) + or( + and(equal(version, "2.0.1"), notequal(cuda, "11.8.0")), + and(equal(version, "2.1.0"), notequal(cuda, "12.1.1")) + ) ? "${base}/cu118" : "${base}/${cudaName(cuda)}" )