Skip to content

Commit

Permalink
nn chg separate string generation from seq sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
jarbus committed Jul 11, 2024
1 parent 94787a7 commit 6a74aca
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/environments/repeatsequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ Base.@kwdef struct RepeatSequence <: AbstractEnvironment
seq_len::Int
n_repeat::Int
end
"""
    split_into_strings(n)

Left-pad `n` with `'0'` to a width of three and return its three characters
as a tuple `(c1, c2, c3)`.

Note that the elements of the returned tuple are `Char`s, not `String`s.

# Throws
- `ArgumentError`: if the padded text is not exactly three characters long
  (`lpad` never truncates, so longer inputs fail here) or if any character
  is not a decimal digit.
"""
function split_into_strings(n)
    # `lpad` stringifies its argument, so integers and strings are both accepted.
    n_str = lpad(n, 3, '0')

    # Guard clause: exactly three characters, every one of them a digit.
    is_three_digits = length(n_str) == 3 && all(isdigit, n_str)
    is_three_digits || throw(ArgumentError("Input must be a 3-digit integer or a valid string representation $n, $n_str"))

    # Iterating the validated string yields its three digit characters in order.
    hundreds, tens, ones = n_str
    return hundreds, tens, ones
end

# ==== PERFORMANCE CRITICAL BEGIN (on server CPUs, which are very slow)
function sample_sequence(n_labels, seq_len, n_repeat, i)
rng = StableRNG(i)
seq = Vector{String}(undef, seq_len)
for i in 1:seq_len
seq[i] = string(rand(rng, 0:n_labels))
end
i_base = digits(i-1, base=n_labels) .|> string |> join
seq = split_into_strings(i_base)
concat_seq = join(seq, " ")
repeat_seq = join((concat_seq for i = 1:n_repeat), " ")
repeat_seq
Expand All @@ -26,7 +31,6 @@ function sample_batch(env::RepeatSequence)
seqs = [(sample_sequence(env.n_labels, env.seq_len, env.n_repeat, i),) for i in 1:env.batch_size]
batch = batched(seqs)
batch[1] # get decoder batch

end

function shift_decode_loss(logits, trg, trg_mask)
Expand Down Expand Up @@ -85,7 +89,7 @@ function get_preprocessed_batch(env, tfr)
# Allocating a large amount of memory on the CPU appears to alleviate this
# issue. Garbage collection does not help. Unable to justify spending
# more time on this, if it's resolved.
size(zeros(500_000))
size(zeros(1_000_000))
Main.preprocessed_batch |> deepcopy |> gpu
end

Expand Down

0 comments on commit 6a74aca

Please sign in to comment.