diff --git a/pkg/converter/bert/convert.go b/pkg/converter/bert/convert.go
index 0452471..ff289eb 100644
--- a/pkg/converter/bert/convert.go
+++ b/pkg/converter/bert/convert.go
@@ -136,6 +136,8 @@ func Convert[T float.DType](modelDir string, overwriteIfExist bool) error {
 
 func mapBaseModel[T float.DType](config bert.Config, pyParams *pytorch.ParamsProvider[T], params paramsMap, vocab *vocabulary.Vocabulary) *bert.Model {
 	baseModel := bert.New[T](config)
+	baseModel.Embeddings.Vocab = vocab
+
 	{
 		source := pyParams.Pop("bert.embeddings.word_embeddings.weight")
 		size := baseModel.Embeddings.Tokens.Dim
diff --git a/pkg/models/bart/config.go b/pkg/models/bart/config.go
index 8571ef2..c164723 100644
--- a/pkg/models/bart/config.go
+++ b/pkg/models/bart/config.go
@@ -80,6 +80,11 @@ func ConfigFromFile(file string) (Config, error) {
 	if err != nil {
 		return Config{}, err
 	}
+
+	// Set default values
+	if config.MaxLength == 0 {
+		config.MaxLength = config.MaxPositionEmbeddings
+	}
 	return config, nil
 }
 
diff --git a/pkg/models/bart/embeddings.go b/pkg/models/bart/embeddings.go
index 6732954..3f211bd 100644
--- a/pkg/models/bart/embeddings.go
+++ b/pkg/models/bart/embeddings.go
@@ -57,10 +57,10 @@ func NewEmbeddings[T float.DType](c Config, shared embedding.Shared, isDecoder b
 
 // Encode performs the Bart initial input encoding.
 func (m *Embeddings) Encode(inputIDs []int, offset int) []mat.Tensor {
-	ys := ag.Map2(ag.Add,
-		m.useScaledEmbeddings(m.SharedEmbeddings.MustEncode(inputIDs)),
-		m.PositionalEncoder.Encode(makePositions(len(inputIDs), offset)),
-	)
+	scaled := m.useScaledEmbeddings(m.SharedEmbeddings.MustEncode(inputIDs))
+	positions := m.PositionalEncoder.Encode(makePositions(len(inputIDs), offset))
+	ys := ag.Map2(ag.Add, scaled, positions)
+
 	if m.Config.NormalizeEmbedding {
 		ys = m.Norm.Forward(ys...)
 	}
diff --git a/pkg/models/bart/positionalencoder.go b/pkg/models/bart/positionalencoder.go
index 9043fb2..da66eb5 100644
--- a/pkg/models/bart/positionalencoder.go
+++ b/pkg/models/bart/positionalencoder.go
@@ -6,6 +6,7 @@ package bart
 
 import (
 	"encoding/gob"
+	"log"
 
 	"github.com/nlpodyssey/spago/mat"
 	"github.com/nlpodyssey/spago/mat/float"
@@ -41,7 +42,7 @@ func init() {
 
 // NewPositionalEncoder returns a new PositionalEncoder.
 func NewPositionalEncoder[T float.DType](config PositionalEncoderConfig) *PositionalEncoder {
-	e := embedding.New[T](config.NumEmbeddings, config.EmbeddingDim)
+	e := embedding.New[T](config.NumEmbeddings+config.Offset, config.EmbeddingDim)
 	size := config.EmbeddingDim
 	half := (size + (size % 2)) / 2
 
@@ -56,7 +57,10 @@ func NewPositionalEncoder[T float.DType](config PositionalEncoderConfig) *Positi
 				data[half+j/2] = mat.Cos(v)
 			}
 		}
-		item, _ := e.Embedding(i)
+		item, err := e.Embedding(i)
+		if err != nil {
+			log.Fatalf("positional encoder: error getting embedding: %s", err)
+		}
 		item.ReplaceValue(mat.NewDense[T](mat.WithBacking(data)))
 	}
 	return &PositionalEncoder{
diff --git a/pkg/models/bert/bert_for_masked_lm.go b/pkg/models/bert/bert_for_masked_lm.go
index 1b201a6..83e2912 100644
--- a/pkg/models/bert/bert_for_masked_lm.go
+++ b/pkg/models/bert/bert_for_masked_lm.go
@@ -47,7 +47,7 @@ func NewModelForMaskedLM[T float.DType](bert *Model) *ModelForMaskedLM {
 
 // Predict returns the predictions for the token associated to the masked nodes.
 func (m *ModelForMaskedLM) Predict(tokens []string) map[int]mat.Tensor {
-	encoded := evaluate(m.Bert.EncodeTokens(tokens)...)
+	encoded := waitForComputation(m.Bert.EncodeTokens(tokens)...)
 	result := make(map[int]mat.Tensor)
 	for _, id := range masked(tokens) {
 		result[id] = m.Layers.Forward(encoded[id])[0]
@@ -55,7 +55,7 @@ func (m *ModelForMaskedLM) Predict(tokens []string) map[int]mat.Tensor {
 	return result
 }
 
-func evaluate(xs ...mat.Tensor) []mat.Tensor {
+func waitForComputation(xs ...mat.Tensor) []mat.Tensor {
 	for _, x := range xs {
 		x.Value()
 	}
diff --git a/pkg/models/bert/embeddings.go b/pkg/models/bert/embeddings.go
index 4fddec8..29cbb31 100644
--- a/pkg/models/bert/embeddings.go
+++ b/pkg/models/bert/embeddings.go
@@ -6,6 +6,7 @@ package bert
 
 import (
 	"github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
+	"github.com/nlpodyssey/cybertron/pkg/vocabulary"
 	"github.com/nlpodyssey/spago/ag"
 	"github.com/nlpodyssey/spago/mat"
 	"github.com/nlpodyssey/spago/mat/float"
@@ -18,6 +19,7 @@ import (
 // Embeddings implements a Bert input embedding module.
 type Embeddings struct {
 	nn.Module
+	Vocab      *vocabulary.Vocabulary
 	Tokens     *emb.Model // string
 	Positions  *emb.Model
 	TokenTypes *emb.Model
@@ -46,7 +48,7 @@ func NewEmbeddings[T float.DType](c Config) *Embeddings {
 // EncodeTokens performs the Bert input encoding.
 func (m *Embeddings) EncodeTokens(tokens []string) []mat.Tensor {
 	var (
-		encoded      = m.Tokens.MustEncode([]int{}) // TODO: temporary []int{} should the tokens be []int?
+		encoded      = m.Tokens.MustEncode(m.tokensToIDs(tokens))
 		positions    = m.Positions.MustEncode(indices(len(tokens)))
 		tokenType, _ = m.TokenTypes.Embedding(0)
 	)
@@ -62,6 +64,15 @@ func (m *Embeddings) EncodeTokens(tokens []string) []mat.Tensor {
 	return m.useProjection(m.Norm.Forward(encoded...))
 }
 
+// tokensToIDs returns the IDs of the given tokens.
+func (m *Embeddings) tokensToIDs(tokens []string) []int {
+	IDs := make([]int, len(tokens))
+	for i, token := range tokens {
+		IDs[i] = m.Vocab.MustID(token)
+	}
+	return IDs
+}
+
 // useProjection returns the output of the projector if it is not nil, otherwise the input.
 func (m *Embeddings) useProjection(xs []mat.Tensor) []mat.Tensor {
 	if m.Projector == nil {
diff --git a/pkg/tasks/zeroshotclassifier/bart/zeroshotclassifier.go b/pkg/tasks/zeroshotclassifier/bart/zeroshotclassifier.go
index e62fb21..6203345 100644
--- a/pkg/tasks/zeroshotclassifier/bart/zeroshotclassifier.go
+++ b/pkg/tasks/zeroshotclassifier/bart/zeroshotclassifier.go
@@ -19,6 +19,7 @@ import (
 	"github.com/nlpodyssey/spago/mat"
 	"github.com/nlpodyssey/spago/mat/float"
 	"github.com/nlpodyssey/spago/nn"
+	"github.com/nlpodyssey/spago/nn/embedding"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -51,6 +52,9 @@ func LoadZeroShotClassifier(modelPath string) (*ZeroShotClassifier, error) {
 		return nil, fmt.Errorf("failed to load bart model: %w", err)
 	}
 
+	m.Bart.Encoder.Embeddings.SharedEmbeddings = embedding.Shared{Model: m.Bart.Embeddings}
+	m.Bart.Decoder.Embeddings.SharedEmbeddings = embedding.Shared{Model: m.Bart.Embeddings}
+
 	entailmentID, err := m.Bart.Config.EntailmentID()
 	if err != nil {
 		return nil, err
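Note on the pkg/models/bert/embeddings.go change: the EncodeTokens fix replaces the stubbed []int{} with a real vocabulary lookup through the new tokensToIDs helper. Below is a minimal, self-contained sketch of that lookup; minimalVocab and the example IDs are hypothetical stand-ins, and only the MustID contract (return the ID of a known token, fail hard on an unknown one) is taken from the diff.

package main

import "fmt"

// minimalVocab is a hypothetical stand-in for vocabulary.Vocabulary,
// mimicking only the MustID method the diff relies on.
type minimalVocab map[string]int

// MustID returns the ID of a known token and panics on an unknown one,
// mirroring the contract tokensToIDs assumes.
func (v minimalVocab) MustID(token string) int {
	id, ok := v[token]
	if !ok {
		panic(fmt.Sprintf("token not in vocabulary: %q", token))
	}
	return id
}

// tokensToIDs reproduces the loop introduced in pkg/models/bert/embeddings.go.
func tokensToIDs(vocab minimalVocab, tokens []string) []int {
	ids := make([]int, len(tokens))
	for i, token := range tokens {
		ids[i] = vocab.MustID(token)
	}
	return ids
}

func main() {
	// The IDs here are illustrative, not taken from a real BERT vocabulary file.
	vocab := minimalVocab{"[CLS]": 101, "hello": 7592, "[SEP]": 102}
	fmt.Println(tokensToIDs(vocab, []string{"[CLS]", "hello", "[SEP]"}))
	// Output: [101 7592 102]
}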