Skip to content

Commit

Permalink
Check for NaN
Browse files Browse the repository at this point in the history
  • Loading branch information
tomuben committed Jan 3, 2025
1 parent b2379ae commit 0ca39d2
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test_container/tests/test/python3-cuda-flavor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import urllib.request
from pathlib import Path
import math

import requests
from exasol_python_test_framework import udf
Expand All @@ -20,7 +21,7 @@ def test_pytorch(self):
self.query(udf.fixindent('''
CREATE OR REPLACE PYTHON3 SCALAR SCRIPT
test_pytorch(epochs INTEGER)
RETURNS VARCHAR(10000) AS
RETURNS DOUBLE AS
import torch
import torch.nn as nn
Expand Down Expand Up @@ -60,12 +61,13 @@ def forward(self, x):
with torch.no_grad():
y_pred = model(x_train)
mse = criterion(y_pred, y_train)
return f'Mean Squared Error: {mse.item():.4f}'
return mse.item()
/
'''))

row = self.query(f"SELECT pytorchbasic.test_pytorch(1000);")[0]
self.assertIn('Mean Squared Error', row[0])

self.assertFalse(math.isnan(row[0]))


if __name__ == '__main__':
Expand Down

0 comments on commit 0ca39d2

Please sign in to comment.