Skip to content

Commit

Permalink
Update Spago to v1.0.2-0.20231029222829-dea27c85cd66;
Browse files Browse the repository at this point in the history
Replace `ag.Node` with `mat.Tensor`
  • Loading branch information
matteo-grella committed Oct 30, 2023
1 parent 47c6ce7 commit 52d622f
Show file tree
Hide file tree
Showing 104 changed files with 655 additions and 5,736 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
go-version: '1.20.3'
go-version: '1.21.3'
- name: Run tests and generate coverage report
run: go test -coverprofile cover.out -covermode atomic ./...
- name: Upload coverage to Codecov
Expand All @@ -23,7 +23,7 @@ jobs:
steps:
- uses: actions/setup-go@v3
with:
go-version: '1.20.3'
go-version: '1.21.3'
- uses: actions/checkout@v3
- name: go vet
run: go vet ./...
Expand All @@ -34,7 +34,7 @@ jobs:
steps:
- uses: actions/setup-go@v3
with:
go-version: '1.20.3'
go-version: '1.21.3'
- name: Install gocyclo
run: go install github.com/fzipp/gocyclo/cmd/gocyclo@latest
- uses: actions/checkout@v3
Expand All @@ -47,7 +47,7 @@ jobs:
steps:
- uses: actions/setup-go@v3
with:
go-version: '1.20.3'
go-version: '1.21.3'
- name: Install staticcheck
run: go install honnef.co/go/tools/cmd/staticcheck@latest
- uses: actions/checkout@v3
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Usage of server:
-network value
network type for server listening
-task value
type of inference/computation that the model can fulfill ("text2text"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding")
type of inference/computation that the model can fulfill ("textgeneration"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding")
-tls value
whether to enable TLS ("true"|"false")
-tls-cert value
Expand All @@ -82,7 +82,7 @@ For example, to run Cybertron in server mode for Machine Translation (e.g. `en`
```console
echo "CYBERTRON_MODEL=Helsinki-NLP/opus-mt-en-it" > .env
echo "CYBERTRON_MODELS_DIR=models" >> .env
echo "CYBERTRON_MODEL_TASK=text2text" >> .env
echo "CYBERTRON_MODEL_TASK=text-generation" >> .env
```

and execute the following command:
Expand Down
6 changes: 3 additions & 3 deletions cmd/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
type TaskType string

const (
Text2TextTask TaskType = "text2text"
TextGenerationTask TaskType = "text-generation"
ZeroShotClassificationTask TaskType = "zero-shot-classification"
QuestionAnsweringTask TaskType = "question-answering"
TextClassificationTask TaskType = "text-classification"
Expand All @@ -31,7 +31,7 @@ const (

// TaskTypeValues is the list of supported task types.
var TaskTypeValues = []TaskType{
Text2TextTask,
TextGenerationTask,
ZeroShotClassificationTask,
QuestionAnsweringTask,
TextClassificationTask,
Expand Down Expand Up @@ -124,7 +124,7 @@ func (conf *config) bindFlagSet(fs *flag.FlagSet) {
flagParseFunc(tasks.ParseConversionPolicy, &mm.ConversionPolicy))
fs.Func("model-conversion-precision", `floating-point bits of precision to use if the model is converted ("32"|"64")`,
flagParseFunc(tasks.ParseFloatPrecision, &mm.ConversionPrecision))
fs.Func("task", `type of inference/computation that the model can fulfill ("text2text"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding"|"language-modeling")`,
fs.Func("task", `type of inference/computation that the model can fulfill ("text-generation"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding"|"language-modeling")`,
flagParseFunc(ParseTaskType, &conf.task))

s := conf.serverConfig
Expand Down
38 changes: 35 additions & 3 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ import (
"github.com/nlpodyssey/cybertron/pkg/tasks"
"github.com/nlpodyssey/cybertron/pkg/tasks/languagemodeling"
"github.com/nlpodyssey/cybertron/pkg/tasks/questionanswering"
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
"github.com/nlpodyssey/cybertron/pkg/tasks/textclassification"
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
"github.com/nlpodyssey/cybertron/pkg/tasks/tokenclassification"
"github.com/nlpodyssey/cybertron/pkg/tasks/zeroshotclassifier"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/process"
)

const defaultModelsDir = "models"
Expand Down Expand Up @@ -71,6 +74,8 @@ func run() error {
}
defer tasks.Finalize(m)

logMetrics()

requestHandler, err := server.ResolveRequestHandler(m)
if err != nil {
return err
Expand All @@ -84,12 +89,39 @@ func run() error {
return s.Start(ctx)
}

func logMetrics() {
// Set up zerolog to print with human-readable timestamps
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})

// Get total CPU count
totalCpu, _ := cpu.Counts(false)
// Get process CPU percentage
p, _ := process.NewProcess(int32(os.Getpid()))
percent, _ := p.CPUPercent()

log.Info().
Int("total_cpus", totalCpu).
Float64("cpu_used_by_process_percent", percent).
Msg("CPU Metrics")

// Get total available RAM
vmStat, _ := mem.VirtualMemory()
// Get process RAM usage
memInfo, _ := p.MemoryInfo()

log.Info().
Uint64("total_RAM_available", vmStat.Total).
Uint64("RAM_used_by_process", memInfo.RSS).
Msg("RAM Metrics")
}

func loadModelForTask(conf *config) (m any, err error) {
switch conf.task {
case ZeroShotClassificationTask:
return tasks.Load[zeroshotclassifier.Interface](conf.loaderConfig)
case Text2TextTask:
return tasks.Load[text2text.Interface](conf.loaderConfig)
case TextGenerationTask:
return tasks.Load[textgeneration.Interface](conf.loaderConfig)
case QuestionAnsweringTask:
return tasks.Load[questionanswering.Interface](conf.loaderConfig)
case TextClassificationTask:
Expand Down
12 changes: 6 additions & 6 deletions examples/abstractivequestionasnwering/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
//lint:ignore ST1001 allow dot import just to make the example more readable
. "github.com/nlpodyssey/cybertron/examples"
"github.com/nlpodyssey/cybertron/pkg/tasks"
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text/bart"
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration/bart"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
Expand All @@ -35,19 +35,19 @@ func main() {

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")

m, err := tasks.Load[*bart.Text2Text](&tasks.Config{
m, err := tasks.Load[*bart.TextGeneration](&tasks.Config{
ModelsDir: modelsDir,
ModelName: text2text.DefaultModelForAbstractiveQuestionAnswering,
ModelName: textgeneration.DefaultModelForAbstractiveQuestionAnswering,
})
if err != nil {
log.Fatal().Err(err).Send()
}
defer tasks.Finalize(m)

opts := text2text.DefaultOptions()
opts := textgeneration.DefaultOptions()

start := time.Now()
result, err := m.Generate(context.Background(), text2text.PrepareInputForAbstractiveQuestionAnswering(query, passages), opts)
result, err := m.Generate(context.Background(), textgeneration.PrepareInputForAbstractiveQuestionAnswering(query, passages), opts)
if err != nil {
panic(err)
}
Expand Down
4 changes: 2 additions & 2 deletions examples/relationextraction/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
//lint:ignore ST1001 allow dot import just to make the example more readable
. "github.com/nlpodyssey/cybertron/examples"
"github.com/nlpodyssey/cybertron/pkg/tasks"
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
Expand All @@ -38,7 +38,7 @@ func main() {
}
defer tasks.Finalize(m)

opts := text2text.DefaultOptions()
opts := textgeneration.DefaultOptions()

fn := func(text string) error {
start := time.Now()
Expand Down
69 changes: 64 additions & 5 deletions examples/textgeneration/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,39 @@ import (
"context"
"fmt"
"os"
"runtime"
"time"

//lint:ignore ST1001 allow dot import just to make the example more readable
. "github.com/nlpodyssey/cybertron/examples"
"github.com/nlpodyssey/cybertron/pkg/tasks"
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/process"
)

func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
modelsDir := "/Users/mg/Projects/nlpodyssey/cybertron/models" //HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := "Helsinki-NLP/opus-mt-it-en"

m, err := tasks.Load[text2text.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
start := time.Now()
m, err := tasks.Load[textgeneration.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
log.Fatal().Err(err).Send()
}
defer tasks.Finalize(m)

opts := text2text.DefaultOptions()
log.Debug().Msgf("Loaded model %q in %v", modelName, time.Since(start))

logMetrics()

opts := textgeneration.DefaultOptions()

fn := func(text string) error {
start := time.Now()
Expand All @@ -41,6 +50,7 @@ func main() {
}
fmt.Println(time.Since(start).Seconds())
fmt.Println(result.Texts[0])
runtime.GC()
return nil
}

Expand All @@ -49,3 +59,52 @@ func main() {
log.Fatal().Err(err).Send()
}
}

func logMetrics() error {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})

// Get total CPU count
totalCpu, err := cpu.Counts(false)
if err != nil {
return err
}
// Get process CPU percentage
p, err := process.NewProcess(int32(os.Getpid()))
if err != nil {
return err
}
percent, err := p.CPUPercent()
if err != nil {
return err
}

// Log CPU Metrics
log.Info().
Int("total_cpu_cores", totalCpu).
Float64("process_cpu_usage_percent", percent).
Msg("CPU Metrics")

// Get total available RAM
vmStat, err := mem.VirtualMemory()
if err != nil {
return err
}
// Get process RAM usage
memInfo, err := p.MemoryInfo()
if err != nil {
return err
}

// Log RAM Metrics
log.Info().
Float64("total_ram_available_mb", byteToMb(vmStat.Total)).
Float64("process_ram_usage_mb", byteToMb(memInfo.RSS)).
Msg("RAM Metrics")

return nil
}

func byteToMb(b uint64) float64 {
return float64(b) / 1024 / 1024
}
Loading

0 comments on commit 52d622f

Please sign in to comment.