diff --git a/test_container/tests/test/python3-cuda-flavor/pytorch.py b/test_container/tests/test/python3-cuda-flavor/pytorch.py index d9865bca..23d1b4ee 100644 --- a/test_container/tests/test/python3-cuda-flavor/pytorch.py +++ b/test_container/tests/test/python3-cuda-flavor/pytorch.py @@ -5,6 +5,7 @@ import time import urllib.request from pathlib import Path +import math import requests from exasol_python_test_framework import udf @@ -20,7 +21,7 @@ def test_pytorch(self): self.query(udf.fixindent(''' CREATE OR REPLACE PYTHON3 SCALAR SCRIPT test_pytorch(epochs INTEGER) - RETURNS VARCHAR(10000) AS + RETURNS DOUBLE AS import torch import torch.nn as nn @@ -60,12 +61,13 @@ def forward(self, x): with torch.no_grad(): y_pred = model(x_train) mse = criterion(y_pred, y_train) - return f'Mean Squared Error: {mse.item():.4f}' + return mse.item() / ''')) row = self.query(f"SELECT pytorchbasic.test_pytorch(1000);")[0] - self.assertIn('Mean Squared Error', row[0]) + + self.assertFalse(math.isnan(row[0])) if __name__ == '__main__':