Skip to content

Commit

Permalink
Get home dir from system instead hard-coding it
Browse files Browse the repository at this point in the history
  • Loading branch information
byebye committed May 10, 2018
1 parent 3bc04c1 commit ff4ce8a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
11 changes: 8 additions & 3 deletions data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms

KAGGLE_PATH = Path.home() / '.kaggle/competitions/digit-recognizer/'
KAGGLE_TEST_PATH = KAGGLE_PATH / 'test.csv'
KAGGLE_TRAIN_PATH = KAGGLE_PATH / 'train.csv'


class DigitRecognizerDataset(Dataset):
def __init__(self, X, Y, pretransform=False):
Expand Down Expand Up @@ -33,7 +37,8 @@ def __len__(self):
return self.len


def train_validation_split(csv_file, max_rows=None, validation_num=None, pretransform=False):
def train_validation_split(csv_file, max_rows=None, validation_num=None,
pretransform=False):
if max_rows:
data = np.genfromtxt(csv_file, delimiter=',',
skip_header=1, max_rows=max_rows)
Expand Down
6 changes: 4 additions & 2 deletions dtc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import numpy as np
from sklearn import tree

data = np.genfromtxt('/home/arccha/.kaggle/competitions/digit-recognizer/train.csv',
from data import KAGGLE_TRAIN_PATH, KAGGLE_TEST_PATH

data = np.genfromtxt(KAGGLE_TRAIN_PATH,
delimiter=',', skip_header=1)

Y, X = np.split(data, [1], axis=1)

dtc = tree.DecisionTreeClassifier()
dtc = dtc.fit(X, Y)

X = np.genfromtxt('/home/arccha/.kaggle/competitions/digit-recognizer/test.csv',
X = np.genfromtxt(KAGGLE_TEST_PATH,
delimiter=',', skip_header=1)

Y = dtc.predict(X)
Expand Down
16 changes: 6 additions & 10 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import torch
from cuda import *
from data import *
from net import *
from torch.utils.data import DataLoader
from tqdm import tqdm

test_path = Path('/home/arccha/.kaggle/competitions/digit-recognizer/test.csv')
if not test_path.exists():
test_path = '../test.csv'
if not KAGGLE_TEST_PATH.exists():
KAGGLE_TEST_PATH = '../test.csv'
else:
test_path = str(test_path)
KAGGLE_TEST_PATH = str(KAGGLE_TEST_PATH)

test_dataset = get_test_dataset(test_path)
test_dataset = get_test_dataset(KAGGLE_TEST_PATH)
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
shuffle=False, num_workers=1, pin_memory=True)

Expand Down

0 comments on commit ff4ce8a

Please sign in to comment.