Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow 2/1: faster GPU implementation & other small additions #16

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
try:
from mittens.mittens.tf_mittens import Mittens, GloVe
except ImportError:
from mittens.mittens.np_mittens import Mittens, GloVe

__version__ = "0.2.2"
11 changes: 10 additions & 1 deletion mittens/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
try:
try:
from mittens.tf_mittens import Mittens, GloVe
except:
# print("Failed mittens.tf_mittens")
from mittens.mittens.tf_mittens import Mittens, GloVe
except ImportError:
# print("Failed ANY tf_mittens")
try:
from mittens.np_mittens import Mittens, GloVe
except:
# print("Failed mittens.np_mittens")
from mittens.mittens.np_mittens import Mittens, GloVe

__version__ = "0.2"
__version__ = "0.2.2"
25 changes: 23 additions & 2 deletions mittens/mittens_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from copy import copy
import random
import sys
from time import time

import numpy as np

from mittens.doc import BASE_DOC, MITTENS_PARAM_DESCRIPTION
try:
from mittens.doc import BASE_DOC, MITTENS_PARAM_DESCRIPTION
except:
from mittens.mittens.doc import BASE_DOC, MITTENS_PARAM_DESCRIPTION


class MittensBase(object):
Expand All @@ -31,6 +35,19 @@ def __init__(self, n=100, mittens=0.1, xmax=100, alpha=0.75,
self.max_iter = max_iter
self.errors = list()
self.test_mode = test_mode

def message(self, obj, timer=None):
if type(obj) != str:
obj = str(obj)
elapsed = 0
if timer == 'start':
self._msg_time = time()
elif timer == 'stop':
elapsed = time() - self._msg_time
if elapsed > 0:
obj = obj + ' ({:.1f}s)'.format(elapsed)
print("\r" + obj, flush=True)
return

def fit(self,
X,
Expand Down Expand Up @@ -69,14 +86,18 @@ def fit(self,
embedding of the corresponding element in `vocab`.

"""
self.message("Fitting mco {}".format(X.shape))

if fixed_initialization is not None:
assert self.test_mode, \
"Fixed initialization parameters can only be provided" \
" in test mode. Initialize {} with `test_mode=True`.". \
format(self.__class__.split(".")[-1])
self.message(" Dimensions check")
self._check_dimensions(
X, vocab, initial_embedding_dict
)
self.message(" Initializing weights and log(mco)")
weights, log_coincidence = self._initialize(X)
return self._fit(X, weights, log_coincidence,
vocab=vocab,
Expand Down Expand Up @@ -163,7 +184,7 @@ def _progressbar(self, msg, iter_num):
if self.display_progress and \
(iter_num + 1) % self.display_progress == 0:
sys.stderr.write('\r')
sys.stderr.write("Iteration {}: {}".format(iter_num + 1, msg))
sys.stderr.write("Iteration {}: {}\t\t\t".format(iter_num + 1, msg))
sys.stderr.flush()

def __repr__(self):
Expand Down
14 changes: 12 additions & 2 deletions mittens/np_mittens.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
"""
import numpy as np

from mittens.mittens_base import randmatrix, noise
from mittens.mittens_base import MittensBase, GloVeBase
try:
from mittens.mittens_base import randmatrix, noise
from mittens.mittens_base import MittensBase, GloVeBase
except:
from mittens.mittens.mittens_base import randmatrix, noise
from mittens.mittens.mittens_base import MittensBase, GloVeBase



_FRAMEWORK = "NumPy"
Expand All @@ -35,6 +40,11 @@ class Mittens(MittensBase):
framework=_FRAMEWORK,
second=_DESC.format(model=MittensBase._MODEL))

def __init__(self,
**kwargs):
super().__init__(**kwargs)
self.message("NumPy Mittens initialized.")

@property
def framework(self):
return _FRAMEWORK
Expand Down
Loading