Skip to content

Commit

Permalink
state_discriminator3: rename: invalid -> mixed
Browse files Browse the repository at this point in the history
  • Loading branch information
guicho271828 committed Sep 10, 2017
1 parent 5972c8f commit 95f8035
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions state_discriminator3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,40 +57,40 @@ def prune_unreconstructable(sae,data):

def prepare(data_valid, sae):
gen_batch = 10000 if len(data_valid) < 2000 else None
data_invalid = generate_random(data_valid, sae, gen_batch)
data_mixed = generate_random(data_valid, sae, gen_batch)
try:
p = 0
pp = 0
ppp = 0
while len(data_invalid) < len(data_valid) and p < len(data_invalid):
while len(data_mixed) < len(data_valid) and p < len(data_mixed):
p = pp
pp = ppp
ppp = len(data_invalid)
data_invalid = union(data_invalid, generate_random(data_valid, sae, gen_batch))
ppp = len(data_mixed)
data_mixed = union(data_mixed, generate_random(data_valid, sae, gen_batch))
print("valid:",len(data_valid),
"mixed:", len(data_invalid),
"mixed:", len(data_mixed),
"## generation stops when it failed to generate new examples three times in a row")
except KeyboardInterrupt:
pass
finally:
print("generation stopped")

if len(data_valid) < len(data_invalid):
if len(data_valid) < len(data_mixed):
# downsample
data_invalid = data_invalid[:len(data_valid)]
data_mixed = data_mixed[:len(data_valid)]
else:
# oversample
data_invalid = np.repeat(data_invalid, 1+(len(data_valid)//len(data_invalid)), axis=0)
data_invalid = data_invalid[:len(data_valid)]
data_mixed = np.repeat(data_mixed, 1+(len(data_valid)//len(data_mixed)), axis=0)
data_mixed = data_mixed[:len(data_valid)]

train_in, train_out, test_in, test_out = prepare_binary_classification_data(data_valid, data_invalid)
return train_in, train_out, test_in, test_out, data_valid, data_invalid
train_in, train_out, test_in, test_out = prepare_binary_classification_data(data_valid, data_mixed)
return train_in, train_out, test_in, test_out, data_valid, data_mixed

def prepare_random(data_valid, sae, inflation=1):
batch = data_valid.shape[0]
data_invalid = np.random.randint(0,2,data_valid.shape,dtype=np.int8)
train_in, train_out, test_in, test_out = prepare_binary_classification_data(data_valid, data_invalid)
return train_in, train_out, test_in, test_out, data_valid, data_invalid
data_mixed = np.random.randint(0,2,data_valid.shape,dtype=np.int8)
train_in, train_out, test_in, test_out = prepare_binary_classification_data(data_valid, data_mixed)
return train_in, train_out, test_in, test_out, data_valid, data_mixed

sae = None
cae = None
Expand All @@ -110,8 +110,8 @@ def learn(method):
'min_grad' : 0.0,
}
data_valid = np.loadtxt(sae.local("states.csv"),dtype=np.int8)
train_in, train_out, test_in, test_out, data_valid, data_invalid = prepare(data_valid,sae)
sae.plot_autodecode(data_invalid[:8], "_sd3/fake_samples.png")
train_in, train_out, test_in, test_out, data_valid, data_mixed = prepare(data_valid,sae)
sae.plot_autodecode(data_mixed[:8], "_sd3/fake_samples.png")

if method == "feature":
# decode into image, extract features and learn from it
Expand Down Expand Up @@ -257,12 +257,12 @@ def test(method):

################################################################
# type 2 error
_,_,_,_, _, states_invalid = prepare(states_valid[:50000],sae)
print(len(states_invalid),"reconstructable states generated.")
_,_,_,_, _, states_mixed = prepare(states_valid[:50000],sae)
print(len(states_mixed),"reconstructable states generated.")

p = latplan.util.puzzle_module(sae.path)
is_valid = p.validate_states(sae.decode_binary(states_invalid))
states_invalid = states_invalid[np.logical_not(is_valid)]
is_valid = p.validate_states(sae.decode_binary(states_mixed))
states_invalid = states_mixed[np.logical_not(is_valid)]
states_invalid = states_invalid[:30000]
print(len(states_invalid),"invalid states generated.")

Expand Down

0 comments on commit 95f8035

Please sign in to comment.