
Fix bugs due to spago upgrade
matteo-grella committed Oct 30, 2023
1 parent 52d622f commit 9d9e067
Showing 7 changed files with 35 additions and 9 deletions.
2 changes: 2 additions & 0 deletions pkg/converter/bert/convert.go
@@ -136,6 +136,8 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 func mapBaseModel[T float.DType](config bert.Config, pyParams *pytorch.ParamsProvider[T], params paramsMap, vocab *vocabulary.Vocabulary) *bert.Model {
     baseModel := bert.New[T](config)
 
+    baseModel.Embeddings.Vocab = vocab
+
     {
         source := pyParams.Pop("bert.embeddings.word_embeddings.weight")
         size := baseModel.Embeddings.Tokens.Dim
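The assignment above pairs with the new Vocab field added to bert.Embeddings in pkg/models/bert/embeddings.go further down: EncodeTokens now resolves token strings through that vocabulary, so a converted model without it would have nothing to look IDs up in. A minimal sketch of a guard the converter could add (hypothetical, not part of this commit):

    // Hypothetical safety check, not in the diff: fail fast if the
    // vocabulary was never attached to the embeddings module.
    if baseModel.Embeddings.Vocab == nil {
        panic("bert converter: embeddings vocabulary not set")
    }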
5 changes: 5 additions & 0 deletions pkg/models/bart/config.go
@@ -80,6 +80,11 @@ func ConfigFromFile(file string) (Config, error) {
     if err != nil {
         return Config{}, err
     }
+
+    // Set default values
+    if config.MaxLength == 0 {
+        config.MaxLength = config.MaxPositionEmbeddings
+    }
     return config, nil
 }

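The effect of the new default, sketched under the assumption of a config.json that omits max_length, so the field unmarshals to zero (path and values are illustrative only):

    // Sketch, not part of the commit.
    config, err := bart.ConfigFromFile("models/bart-large-mnli/config.json") // hypothetical path
    if err != nil {
        log.Fatal(err)
    }
    // With the fix, MaxLength falls back to MaxPositionEmbeddings
    // (e.g. 1024) instead of remaining 0.
    fmt.Println(config.MaxLength)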
8 changes: 4 additions & 4 deletions pkg/models/bart/embeddings.go
@@ -57,10 +57,10 @@ func NewEmbeddings[T float.DType](c Config, shared embedding.Shared, isDecoder bool) *Embeddings {

 // Encode performs the Bart initial input encoding.
 func (m *Embeddings) Encode(inputIDs []int, offset int) []mat.Tensor {
-    ys := ag.Map2(ag.Add,
-        m.useScaledEmbeddings(m.SharedEmbeddings.MustEncode(inputIDs)),
-        m.PositionalEncoder.Encode(makePositions(len(inputIDs), offset)),
-    )
+    scaled := m.useScaledEmbeddings(m.SharedEmbeddings.MustEncode(inputIDs))
+    positions := m.PositionalEncoder.Encode(makePositions(len(inputIDs), offset))
+    ys := ag.Map2(ag.Add, scaled, positions)
+
     if m.Config.NormalizeEmbedding {
         ys = m.Norm.Forward(ys...)
     }
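The refactor only names the two operands before combining them; the behavior is unchanged. For readers unfamiliar with spago's ag.Map2, it applies a binary operator pairwise over two tensor slices, so the call is equivalent to this explicit loop (sketch only, assuming equal-length slices):

    // What ys := ag.Map2(ag.Add, scaled, positions) computes:
    // ys[i] = scaled[i] + positions[i] for every position i.
    ys := make([]mat.Tensor, len(scaled))
    for i := range scaled {
        ys[i] = ag.Add(scaled[i], positions[i])
    }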
8 changes: 6 additions & 2 deletions pkg/models/bart/positionalencoder.go
@@ -6,6 +6,7 @@ package bart

 import (
     "encoding/gob"
+    "log"
 
     "github.com/nlpodyssey/spago/mat"
     "github.com/nlpodyssey/spago/mat/float"
@@ -41,7 +42,7 @@ func init() {

 // NewPositionalEncoder returns a new PositionalEncoder.
 func NewPositionalEncoder[T float.DType](config PositionalEncoderConfig) *PositionalEncoder {
-    e := embedding.New[T](config.NumEmbeddings, config.EmbeddingDim)
+    e := embedding.New[T](config.NumEmbeddings+config.Offset, config.EmbeddingDim)
 
     size := config.EmbeddingDim
     half := (size + (size % 2)) / 2
@@ -56,7 +57,10 @@ func NewPositionalEncoder[T float.DType](config PositionalEncoderConfig) *PositionalEncoder {
                 data[half+j/2] = mat.Cos(v)
             }
         }
-        item, _ := e.Embedding(i)
+        item, err := e.Embedding(i)
+        if err != nil {
+            log.Fatalf("positional encoder: error getting embedding: %s", err)
+        }
         item.ReplaceValue(mat.NewDense[T](mat.WithBacking(data)))
     }
     return &PositionalEncoder{
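Two fixes here: the embedding table gains config.Offset extra rows, and a lookup error is now fatal instead of silently discarded. The sizing matters because BART-style positional tables are presumably queried at position+Offset (makePositions shifts by the given offset), so the last Offset rows were missing before. A sketch of the arithmetic, with values assumed rather than taken from the commit:

    // Sketch, not part of the commit. Assuming NumEmbeddings = 1024 and
    // Offset = 2 (typical BART values): positions 0..1023 are looked up at
    // rows 2..1025, so the table needs 1024+2 = 1026 rows. The old
    // allocation of 1024 rows made the last lookups fail, and the error
    // was dropped by the old `item, _ :=` form.
    e := embedding.New[float32](1024+2, 768)
    _, err := e.Embedding(1025) // now in range; previously an error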
4 changes: 2 additions & 2 deletions pkg/models/bert/bert_for_masked_lm.go
@@ -47,15 +47,15 @@ func NewModelForMaskedLM[T float.DType](bert *Model) *ModelForMaskedLM {

 // Predict returns the predictions for the token associated to the masked nodes.
 func (m *ModelForMaskedLM) Predict(tokens []string) map[int]mat.Tensor {
-    encoded := evaluate(m.Bert.EncodeTokens(tokens)...)
+    encoded := waitForComputation(m.Bert.EncodeTokens(tokens)...)
     result := make(map[int]mat.Tensor)
     for _, id := range masked(tokens) {
         result[id] = m.Layers.Forward(encoded[id])[0]
     }
     return result
 }
 
-func evaluate(xs ...mat.Tensor) []mat.Tensor {
+func waitForComputation(xs ...mat.Tensor) []mat.Tensor {
     for _, x := range xs {
         x.Value()
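The rename is purely descriptive: in spago's lazily evaluated graph, calling Value() blocks until a tensor's forward computation completes, so the helper waits for the asynchronous encoding rather than computing anything new. A hypothetical use of Predict (the "[MASK]" literal is an assumption; the masked() helper is not shown in this hunk):

    // Sketch, not part of the commit.
    tokens := []string{"the", "capital", "of", "france", "is", "[MASK]", "."}
    for pos, scores := range model.Predict(tokens) {
        fmt.Println("masked position:", pos, "scores:", scores.Value())
    }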
13 changes: 12 additions & 1 deletion pkg/models/bert/embeddings.go
@@ -6,6 +6,7 @@ package bert

 import (
     "github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
+    "github.com/nlpodyssey/cybertron/pkg/vocabulary"
     "github.com/nlpodyssey/spago/ag"
     "github.com/nlpodyssey/spago/mat"
     "github.com/nlpodyssey/spago/mat/float"
@@ -18,6 +19,7 @@ import (
 // Embeddings implements a Bert input embedding module.
 type Embeddings struct {
     nn.Module
+    Vocab      *vocabulary.Vocabulary
     Tokens     *emb.Model // string
     Positions  *emb.Model
     TokenTypes *emb.Model
@@ -46,7 +48,7 @@ func NewEmbeddings[T float.DType](c Config) *Embeddings {
 // EncodeTokens performs the Bert input encoding.
 func (m *Embeddings) EncodeTokens(tokens []string) []mat.Tensor {
     var (
-        encoded      = m.Tokens.MustEncode([]int{}) // TODO: temporary []int{} should the tokens be []int?
+        encoded      = m.Tokens.MustEncode(m.tokensToIDs(tokens))
         positions    = m.Positions.MustEncode(indices(len(tokens)))
         tokenType, _ = m.TokenTypes.Embedding(0)
     )
@@ -62,6 +64,15 @@ func (m *Embeddings) EncodeTokens(tokens []string) []mat.Tensor {
     return m.useProjection(m.Norm.Forward(encoded...))
 }
 
+// tokensToIDs returns the IDs of the given tokens.
+func (m *Embeddings) tokensToIDs(tokens []string) []int {
+    IDs := make([]int, len(tokens))
+    for i, token := range tokens {
+        IDs[i] = m.Vocab.MustID(token)
+    }
+    return IDs
+}
+
 // useProjection returns the output of the projector if it is not nil, otherwise the input.
 func (m *Embeddings) useProjection(xs []mat.Tensor) []mat.Tensor {
     if m.Projector == nil {
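This resolves the TODO on the replaced line: EncodeTokens keeps its []string signature, and the module maps strings to IDs itself through the vocabulary attached by the converter. A usage sketch (tokens assumed in-vocabulary; how MustID treats unknown tokens is not shown in this hunk):

    // Sketch, not part of the commit: encoding a wordpiece-tokenized input.
    tokens := []string{"[CLS]", "hello", "world", "[SEP]"}
    encoded := m.EncodeTokens(tokens) // one mat.Tensor per input token
    _ = encoded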
4 changes: 4 additions & 0 deletions pkg/tasks/zeroshotclassifier/bart/zeroshotclassifier.go
@@ -19,6 +19,7 @@ import (
     "github.com/nlpodyssey/spago/mat"
     "github.com/nlpodyssey/spago/mat/float"
     "github.com/nlpodyssey/spago/nn"
+    "github.com/nlpodyssey/spago/nn/embedding"
     "golang.org/x/sync/errgroup"
 )

@@ -51,6 +52,9 @@ func LoadZeroShotClassifier(modelPath string) (*ZeroShotClassifier, error) {
         return nil, fmt.Errorf("failed to load bart model: %w", err)
     }
 
+    m.Bart.Encoder.Embeddings.SharedEmbeddings = embedding.Shared{Model: m.Bart.Embeddings}
+    m.Bart.Decoder.Embeddings.SharedEmbeddings = embedding.Shared{Model: m.Bart.Embeddings}
+
     entailmentID, err := m.Bart.Config.EntailmentID()
     if err != nil {
         return nil, err
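BART ties the encoder and decoder token embeddings to a single shared table; presumably the upgraded spago no longer restores that tie on deserialization, so the loader re-establishes it by hand. After the two assignments, both sides wrap the same model, which a quick check can confirm (sketch only):

    // Sketch, not part of the commit: both Shared wrappers now hold the
    // same underlying embedding model.
    same := m.Bart.Encoder.Embeddings.SharedEmbeddings.Model ==
        m.Bart.Decoder.Embeddings.SharedEmbeddings.Model
    fmt.Println(same) // true: one table, two views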
