Skip to content

Commit

Permalink
Fixed test
Browse files Browse the repository at this point in the history
  • Loading branch information
leschultz committed Apr 19, 2024
1 parent 28af8cd commit 01f6413
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/materials/combined/run.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
export PYTHONPATH=../../../src/:$PYTHONPATH
export PYTHONPATH=$(pwd)/../../../src/:$PYTHONPATH
torchrun fit.py
2 changes: 1 addition & 1 deletion examples/synthetic/run.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
export PYTHONPATH=../../src/:$PYTHONPATH
export PYTHONPATH=$(pwd)/../../src/:$PYTHONPATH
torchrun fit.py
6 changes: 3 additions & 3 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ def test_ml(self):
lr = 1e-4
batch_size = 32
n_epochs = 10
tasks = ['data1', 'data2', 'data3']
tasks = ['toy1', 'toy2', 'friedman1']

# Data
X, y = datasets.toy()
X, y = datasets.load(tasks)
data = datasets.splitter(X, y, tasks, train_size=1)

for k, v in data.items():
data[k]['scaler'] = StandardScaler()
data[k]['loss'] = nn.L1Loss()

model = models.MultiNet(tasks=tasks, input_arch={500: 1})
model = models.MultiNet(tasks=tasks, input_arch={50: 1})
optimizer = optim.Adam

out = utils.train(
Expand Down

0 comments on commit 01f6413

Please sign in to comment.