diff --git a/examples/abstractivequestionasnwering/main.go b/examples/abstractivequestionasnwering/main.go index e754f34..5a78473 100644 --- a/examples/abstractivequestionasnwering/main.go +++ b/examples/abstractivequestionasnwering/main.go @@ -33,7 +33,7 @@ func main() { zerolog.SetGlobalLevel(zerolog.TraceLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") m, err := tasks.Load[*bart.TextGeneration](&tasks.Config{ ModelsDir: modelsDir, diff --git a/examples/languagemodeling/main.go b/examples/languagemodeling/main.go index e3b0819..5193de0 100644 --- a/examples/languagemodeling/main.go +++ b/examples/languagemodeling/main.go @@ -21,8 +21,8 @@ func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := HasEnvVar("CYBERTRON_MODEL") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", languagemodeling.DefaultModel) m, err := tasks.Load[languagemodeling.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) if err != nil { diff --git a/examples/questionanswering/main.go b/examples/questionanswering/main.go index 74c9074..92aef95 100644 --- a/examples/questionanswering/main.go +++ b/examples/questionanswering/main.go @@ -18,13 +18,15 @@ import ( "github.com/rs/zerolog/log" ) +// Example of content to be used as context for the question answering task. +const content = `Cloud computing is a technology that allows individuals and businesses to access computing resources over the Internet. It enables users to utilize hardware and software that are managed by third parties at remote locations. Services provided by cloud computing include storage solutions, databases, and computing power, which can be used on a pay-per-use basis. This model offers flexibility and scalability, reducing the need for large upfront investments in infrastructure. 
Major providers of cloud computing services include Amazon Web Services (AWS), Microsoft Azure, and Google Cloud Platform (GCP).` + func main() { zerolog.SetGlobalLevel(zerolog.TraceLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := HasEnvVar("CYBERTRON_MODEL") - paragraph := HasEnvVar("CYBERTRON_QA_PARAGRAPH") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", questionanswering.DefaultEnglishModel) m, err := tasks.Load[questionanswering.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) if err != nil { @@ -36,7 +38,7 @@ func main() { fn := func(text string) error { start := time.Now() - result, err := m.Answer(context.Background(), text, paragraph, opts) + result, err := m.Answer(context.Background(), text, content, opts) if err != nil { return err } @@ -45,7 +47,7 @@ func main() { return nil } - fmt.Println(paragraph) + fmt.Println(content) err = ForEachInput(os.Stdin, fn) if err != nil { diff --git a/examples/relationextraction/main.go b/examples/relationextraction/main.go index c238a04..88a3754 100644 --- a/examples/relationextraction/main.go +++ b/examples/relationextraction/main.go @@ -27,7 +27,7 @@ func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") m, err := tasks.LoadModelForTextGeneration(&tasks.Config{ ModelsDir: modelsDir, diff --git a/examples/textclassification/main.go b/examples/textclassification/main.go index bce89a3..d7cab73 100644 --- a/examples/textclassification/main.go +++ b/examples/textclassification/main.go @@ -18,12 +18,14 @@ import ( "github.com/rs/zerolog/log" ) +const limit = 5 // number of labels to show + func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := HasEnvVar("CYBERTRON_MODEL") + modelsDir := 
HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", textclassification.DefaultModelForGeographicCategorizationMulti) m, err := tasks.Load[textclassification.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) if err != nil { @@ -38,7 +40,10 @@ func main() { return err } fmt.Println(time.Since(start).Seconds()) - fmt.Println(result) + + for i := 0; i < limit && i < len(result.Labels); i++ { + fmt.Printf("%s\t%0.3f\n", result.Labels[i], result.Scores[i]) + } return nil } diff --git a/examples/textencoding/main.go b/examples/textencoding/main.go index b6f9683..92ee9e9 100644 --- a/examples/textencoding/main.go +++ b/examples/textencoding/main.go @@ -18,14 +18,14 @@ import ( "github.com/rs/zerolog/log" ) -const limit = 10 +const limit = 10 // number of dimensions to show func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := HasEnvVar("CYBERTRON_MODEL") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", textencoding.DefaultModelMulti) m, err := tasks.Load[textencoding.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) if err != nil { diff --git a/examples/textgeneration/main.go b/examples/textgeneration/main.go index b53fcd0..0939bfd 100644 --- a/examples/textgeneration/main.go +++ b/examples/textgeneration/main.go @@ -26,8 +26,8 @@ func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := "/Users/mg/Projects/nlpodyssey/cybertron/models" //HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := "Helsinki-NLP/opus-mt-it-en" + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", textgeneration.DefaultModelForMachineTranslation("en", "it")) start := time.Now() m, err := tasks.Load[textgeneration.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) diff --git 
a/examples/tokenclassification/main.go b/examples/tokenclassification/main.go index bd21b6b..227a3ba 100644 --- a/examples/tokenclassification/main.go +++ b/examples/tokenclassification/main.go @@ -22,8 +22,8 @@ func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := HasEnvVar("CYBERTRON_MODEL") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", tokenclassification.DefaultEnglishModel) m, err := tasks.Load[tokenclassification.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) if err != nil { diff --git a/examples/utils.go b/examples/utils.go index 772ae33..f20c591 100644 --- a/examples/utils.go +++ b/examples/utils.go @@ -43,6 +43,16 @@ func HasEnvVar(key string) string { return value } +// HasEnvVarOr returns the value of the environment variable with the given key. +// It returns the alternative value if the environment variable is not set. 
+func HasEnvVarOr(key string, alt string) string { + value := os.Getenv(key) + if strings.Trim(value, " ") == "" { + return alt + } + return value +} + // MarshalJSON returns the JSON string representation of the input data func MarshalJSON(data any) string { m, _ := json.MarshalIndent(data, "", " ") diff --git a/examples/zeroshotclassification/main.go b/examples/zeroshotclassification/main.go index 6332e85..c20222c 100644 --- a/examples/zeroshotclassification/main.go +++ b/examples/zeroshotclassification/main.go @@ -23,9 +23,13 @@ func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) LoadDotenv() - modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR") - modelName := HasEnvVar("CYBERTRON_MODEL") - possibleClasses := HasEnvVar("CYBERTRON_ZERO_SHOT_POSSIBLE_CLASSES") + modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models") + modelName := HasEnvVarOr("CYBERTRON_MODEL", zeroshotclassifier.DefaultModel) + + if len(os.Args) < 2 { + log.Fatal().Msg("missing possible classes (comma separated)") + } + possibleClasses := os.Args[1] m, err := tasks.Load[zeroshotclassifier.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName}) if err != nil { @@ -46,7 +50,10 @@ func main() { return err } fmt.Println(time.Since(start).Seconds()) - fmt.Println(result) + + for i := range result.Labels { + fmt.Printf("%s\t%0.3f\n", result.Labels[i], result.Scores[i]) + } return nil } diff --git a/pkg/models/bart/config.go b/pkg/models/bart/config.go index c164723..7dadab0 100644 --- a/pkg/models/bart/config.go +++ b/pkg/models/bart/config.go @@ -85,6 +85,9 @@ func ConfigFromFile(file string) (Config, error) { if config.MaxLength == 0 { config.MaxLength = config.MaxPositionEmbeddings } + if config.NumBeams == 0 { + config.NumBeams = 4 // TODO: check if this is the default value? 
+ } return config, nil } diff --git a/pkg/tasks/tokenclassification/tokenclassification.go b/pkg/tasks/tokenclassification/tokenclassification.go index 3601dd5..05f7595 100644 --- a/pkg/tasks/tokenclassification/tokenclassification.go +++ b/pkg/tasks/tokenclassification/tokenclassification.go @@ -11,10 +11,15 @@ import ( const ( // DefaultEnglishModel is a model for Named Entities Recognition for the English language. + // It supports the following entities (CoNLL-2003 NER dataset): + // LOC, MISC, ORG, PER + DefaultEnglishModel = "dbmdz/bert-large-cased-finetuned-conll03-english" + + // DefaultEnglishModelOntonotes is a model for Named Entities Recognition for the English language. // It supports the following entities: // CARDINAL, DATE, EVENT, FAC, GPE, LANGUAGE, LAW, LOC, MONEY, NORP, ORDINAL, PERCENT, PERSON, PRODUCT, QUANTITY, TIME, WORK_OF_ART // Model card: https://huggingface.co/djagatiya/ner-bert-base-cased-ontonotesv5-englishv4 - DefaultEnglishModel = "djagatiya/ner-bert-base-cased-ontonotesv5-englishv4" + DefaultEnglishModelOntonotes = "djagatiya/ner-bert-base-cased-ontonotesv5-englishv4" // DefaultModelMulti is a multilingual model for Named Entities Recognition supporting 9 languages: // de, en, es, fr, it, nl, pl, pt, ru.