From 6de2a4d1f1c0111849479e2f8be8580809f60802 Mon Sep 17 00:00:00 2001 From: Ahmed Almaghz <53489256+AhmedAlmaghz@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:41:01 +0300 Subject: [PATCH 1/2] [i18n-ar] Translated file : `docs/source/ar/torchscript.md` into Arabic (#33079) * Add docs/source/ar/torchscript.md to Add_docs_source_ar_torchscript.md * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Update docs/source/ar/torchscript.md Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> * Merge troubleshooting.md with this Branch * Update _toctree.yml * Update torchscript.md * Update troubleshooting.md --------- Co-authored-by: Abdullah Mohammed <554032+abodacs@users.noreply.github.com> --- docs/source/ar/_toctree.yml | 8 +- docs/source/ar/torchscript.md | 154 +++++++++++++++++++++++++++ docs/source/ar/troubleshooting.md | 171 ++++++++++++++++++++++++++++++ 3 files changed, 329 insertions(+), 4 deletions(-) create mode 100644 docs/source/ar/torchscript.md create mode 100644 docs/source/ar/troubleshooting.md diff --git a/docs/source/ar/_toctree.yml b/docs/source/ar/_toctree.yml index 67564c43556db7..d9523eaf5da535 100644 --- a/docs/source/ar/_toctree.yml +++ b/docs/source/ar/_toctree.yml @@ -127,16 +127,16 @@ title: التصدير إلى ONNX - local: tflite title: التصدير إلى TFLite -# - local: torchscript -# title: التصدير إلى TorchScript + - local: torchscript + title: التصدير إلى TorchScript # - local: benchmarks # title: المعايير # - local: notebooks # title: دفاتر الملاحظات مع الأمثلة # - local: community # title: موارد المجتمع -# - local: troubleshooting -# title: استكشاف الأخطاء وإصلاحها + - local: troubleshooting + title: استكشاف الأخطاء وإصلاحها - local: gguf title: التوافق مع ملفات GGUF title: أدلة المطورين diff --git a/docs/source/ar/torchscript.md b/docs/source/ar/torchscript.md new file mode 100644 index 00000000000000..bf0bc0dde04b62 --- /dev/null +++ b/docs/source/ar/torchscript.md @@ -0,0 +1,154 @@ +# التصدير إلى TorchScript + + + +هذه هي بداية تجاربنا مع TorchScript ولا زلنا نستكشف قدراته مع نماذج المدخلات المتغيرة الحجم. إنه مجال اهتمامنا وسنعمق تحليلنا في الإصدارات القادمة، مع المزيد من الأمثلة البرمجية، وتنفيذ أكثر مرونة، ومقاييس مقارنة بين الأكواد القائمة على Python مع أكواد TorchScript المُجمّعة. + + + +وفقًا لـ [وثائق TorchScript](https://pytorch.org/docs/stable/jit.html): + +> TorchScript هي طريقة لإنشاء نماذج قابلة للتسلسل والتحسين من تعليمات PyTorch البرمجية. + +هناك وحدتان من PyTorch، [JIT and TRACE](https://pytorch.org/docs/stable/jit.html)، تتيحان للمطورين تصدير نماذجهم لإعادة استخدامها في برامج أخرى مثل برامج C++ المُحسّنة للأداء. + +نقدم واجهة تتيح لك تصدير نماذج 🤗 Transformers إلى TorchScript بحيث يمكن إعادة استخدامها في بيئة مختلفة عن برامج Python القائمة إلى PyTorch. هنا نشرح كيفية تصدير نماذجنا واستخدامها باستخدام TorchScript. + +يتطلب تصدير نموذج أمرين: + +- تهيئة مثيل للنموذج باستخدام علامة `torchscript` +- تمرير مُدخلات وهمية (dummy inputs) خلال النموذج + +تنطوي هذه الضرورات على عدة أمور يجب على المطورين توخي الحذر بشأنها كما هو مفصل أدناه. + +## علامة TorchScript والأوزان المرتبطة + +علامة `torchscript` ضرورية لأن معظم نماذج اللغة 🤗 Transformers لها أوزان مرتبطة بين طبقة `Embedding` وطبقة `Decoding`. لا يسمح لك TorchScript بتصدير النماذج ذات الأوزان المرتبطة، لذلك من الضروري فصل الأوزان ونسخها مسبقًا. + +النماذج المُهيأة باستخدام علامة `torchscript` لها طبقة `Embedding` وطبقة`Decoding` منفصلتين، مما يعني أنه لا ينبغي تدريبها لاحقًا. سيؤدي التدريب إلى عدم تزامن الطبقتين، مما يؤدي إلى نتائج غير متوقعة. + +هذا لا ينطبق على النماذج التي لا تحتوي على رأس نموذج اللغة، حيث لا تملك أوزانًا مرتبطة. يمكن تصدير هذه النماذج بأمان دون علامة `torchscript`. + +## المدخلات الوهمية والأطوال القياسية + +تُستخدم المُدخلات الوهمية لتمرير أمامي خلال النموذج. أثناء انتشار قيم المُدخلات عبر الطبقات، يتتبع PyTorch العمليات المختلفة التي يتم تنفيذها على كل مصفوفة(tensor). ثم يتم استخدام هذه العمليات المُسجلة بعد ذلك لإنشاء *أثر* النموذج. + +يتم إنشاء التتبع بالنسبة لأبعاد المُدخلات. وبالتالي، فهو مُقيّد بأبعاد المُدخلات الوهمية، ولن يعمل لأي طول تسلسل أو حجم دفعة مختلف. عند المحاولة بحجم مختلف، يتم رفع الخطأ التالي: + +``` +`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2` +``` + +نوصي بتتبع النموذج باستخدام حجم مُدخلات وهمية لا يقل عن أكبر مُدخل سيتم تقديمه للنموذج أثناء الاستدلال. يمكن أن تساعد الحشوة(padding) في ملء القيم المفقودة. ومع ذلك، نظرًا لتتبع النموذج بحجم مُدخل أكبر، ستكون أبعاد المصفوفة ستكون كبيرة أيضًا، مما يؤدي عنه المزيد من الحسابات. + +انتبه إلى إجمالي عدد العمليات المُنفذة على كل مُدخل وتابع الأداء عن كثب عند تصدير نماذج متغيرة طول التسلسل. + +## استخدام TorchScript في Python + +يوضح هذا القسم كيفية حفظ النماذج وتحميلها، بالإضافة إلى كيفية استخدام التتبع للاستدلال. + +### حفظ نموذج + +لتصدير `BertModel` باستخدام TorchScript، قم بتهيئة ـ `BertModel` من فئة `BertConfig` ثم احفظه على القرص تحت اسم الملف `traced_bert.pt`: + +```python +from transformers import BertModel, BertTokenizer, BertConfig +import torch + +enc = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + +# Tokenizing input text +text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" +tokenized_text = enc.tokenize(text) + +# Masking one of the input tokens +masked_index = 8 +tokenized_text[masked_index] = "[MASK]" +indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) +segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] + +# Creating a dummy input +tokens_tensor = torch.tensor([indexed_tokens]) +segments_tensors = torch.tensor([segments_ids]) +dummy_input = [tokens_tensor, segments_tensors] + +# Initializing the model with the torchscript flag +# Flag set to True even though it is not necessary as this model does not have an LM Head. +config = BertConfig( + vocab_size_or_config_json_file=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + torchscript=True, +) + +# Instantiating the model +model = BertModel(config) + +# The model needs to be in evaluation mode +model.eval() + +# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag +model = BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True) + +# Creating the trace +traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) +torch.jit.save(traced_model, "traced_bert.pt") +``` + +### تحميل نموذج + +يمكنك الآن تحميل `BertModel` المُحفظ سابقًا، `traced_bert.pt`، من القرص واستخدامه على `dummy_input` المُهيأ سابقًا: + +```python +loaded_model = torch.jit.load("traced_bert.pt") +loaded_model.eval() + +all_encoder_layers, pooled_output = loaded_model(*dummy_input) +``` + +### استخدام نموذج مُتتبع للاستدلال + +استخدم النموذج المُتتبع للاستدلال باستخدام أسلوب `__call__` الخاص به: + +```python +traced_model(tokens_tensor, segments_tensors) +``` + +## نشر نماذج Hugging Face TorchScript على AWS باستخدام Neuron SDK + +قدمت AWS عائلة [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/) من اﻷجهزة لخفض التكلفة وأداء التعلم الآلي عالي الأداء في البيئة السحابية. تعمل أجهزة Inf1 بواسطة شريحة Inferentia من AWS، وهي مُسرّع أجهزة مُخصص، متخصص في أعباء عمل الاستدلال للتعلم العميق. [AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#) هي SDK لـ Inferentia التي تدعم تتبع نماذج المحولات وتحسينها للنشر على Inf1. توفر Neuron SDK ما يلي: + +1. واجهة برمجة تطبيقات سهلة الاستخدام مع تغيير سطر واحد من التعليمات البرمجية لتتبع نموذج TorchScript وتحسينه للاستدلال في البيئة السحابية. +2. تحسينات الأداء الجاهزة للاستخدام [تحسين التكلفة والأداء](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/benchmark/>). +3. دعم نماذج Hugging Face المحولات المبنية باستخدام إما [PyTorch](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/bert_tutorial/tutorial_pretrained_bert.html) أو [TensorFlow](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/tensorflow/huggingface_bert/huggingface_bert.html). + +### الآثار المترتبة + +تعمل نماذج المحولات المستندة إلى بنية [BERT (تمثيلات الترميز ثنائية الاتجاه من المحولات)](https://huggingface.co/docs/transformers/main/model_doc/bert) أو متغيراتها مثل [distilBERT](https://huggingface.co/docs/transformers/main/model_doc/distilbert) و [roBERTa](https://huggingface.co/docs/transformers/main/model_doc/roberta) بشكل أفضل على Inf1 للمهام غير التوليدية مثل الإجابة على الأسئلة الاستخراجية، وتصنيف التسلسلات، وتصنيف الرموز (tokens). ومع ذلك، يمكن تكييف مهام توليد النصوص للعمل على Inf1 وفقًا لهذا [برنامج تعليمي AWS Neuron MarianMT](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html). يمكن العثور على مزيد من المعلومات حول النماذج التي يمكن تحويلها جاهزة على Inferentia في قسم [ملاءمة بنية النموذج](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia) من وثائق Neuron. + +### التبعيات (Dependencies) + +يتطلب استخدام AWS Neuron لتحويل النماذج [بيئة SDK Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/index.html#installation-guide) والتي تأتي مسبقًا على [AMI للتعلم العميق من AWS](https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-inferentia-launching.html). + +### تحويل نموذج لـ AWS Neuron + +قم بتحويل نموذج لـ AWS NEURON باستخدام نفس التعليمات البرمجية من [استخدام TorchScript في Python](torchscript#using-torchscript-in-python) لتتبع `BertModel`. قم باستيراد امتداد إطار عمل `torch.neuron` للوصول إلى مكونات Neuron SDK من خلال واجهة برمجة تطبيقات Python: + +```python +from transformers import BertModel, BertTokenizer, BertConfig +import torch +import torch.neuron +``` + +كل ما عليك فعله هو تعديل السطر التالي: + +```diff +- torch.jit.trace(model, [tokens_tensor, segments_tensors]) ++ torch.neuron.trace(model, [token_tensor, segments_tensors]) +``` + +يتيح ذلك لـ Neuron SDK تتبع النموذج وتحسينه لمثيلات Inf1. + +لمعرفة المزيد حول ميزات AWS Neuron SDK والأدوات ودروس البرامج التعليمية والتحديثات الأخيرة، يرجى الاطلاع على [وثائق AWS NeuronSDK](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html). diff --git a/docs/source/ar/troubleshooting.md b/docs/source/ar/troubleshooting.md new file mode 100644 index 00000000000000..7874a9fad13304 --- /dev/null +++ b/docs/source/ar/troubleshooting.md @@ -0,0 +1,171 @@ +# استكشاف الأخطاء وإصلاحها + +تحدث الأخطاء أحيانًا، لكننا هنا للمساعدة! يغطي هذا الدليل بعض المشكلات الأكثر شيوعًا التي واجهناها وكيفية حلها. مع ذلك، لا يُقصد بهذا الدليل أن يكون مجموعة شاملة لكل مشكلات 🤗 Transformers. لمزيد من المساعدة في استكشاف مشكلتك وإصلاحها، جرب ما يلي: + + + +1. اطلب المساعدة على [المنتديات](https://discuss.huggingface.co/). هناك فئات محددة يمكنك نشر سؤالك فيها، مثل [المبتدئين](https://discuss.huggingface.co/c/beginners/5) أو [🤗 Transformers](https://discuss.huggingface.co/c/transformers/9). تأكد من كتابة منشور جيد وواضح على المنتدى مع بعض التعليمات البرمجية القابلة للتكرار لزيادة احتمالية حل مشكلتك! + + +2. قم بإنشاء [مشكلة](https://github.com/huggingface/transformers/issues/new/choose) في مستودع 🤗 Transformers إذا كانت هناك مشكلة متعلقة بالمكتبة. حاول تضمين أكبر قدر ممكن من المعلومات التي تصف المشكلة لمساعدتنا في معرفة ما هو الخطأ وكيفية إصلاحه. + +3. تحقق من دليل [الترحيل](migration) إذا كنت تستخدم إصدارًا أقدم من مكتبة 🤗 Transformers حيث تم إدخال بعض التغييرات المهمة بين الإصدارات. + + +للحصول على مزيد من التفاصيل حول استكشاف الأخطاء وإصلاحها والحصول على المساعدة، راجع [الفصل 8](https://huggingface.co/course/chapter8/1?fw=pt) من دورة Hugging Face. + +## بيئات جدار الحماية + +بعض وحدات معالجة الرسومات (GPU) على السحابة وإعدادات الشبكة الداخلية محمية بجدار حماية من الاتصالات الخارجية، مما يؤدي إلى حدوث خطأ في الاتصال. عندما تحاول تعليمات البرنامج النصي تنزيل أوزان النموذج أو مجموعات البيانات، سيتوقف التنزيل ثم ينتهي بخطأ مثل: + +``` +ValueError: Connection error, and we cannot find the requested files in the cached path. +Please try again or make sure your Internet connection is on. +``` + +في هذه الحالة، يجب محاولة تشغيل 🤗 Transformers في [وضع عدم الاتصال](installation#offline-mode) لتجنب خطأ الاتصال. + +## CUDA نفاد الذاكرة + +يمكن أن يكون تدريب النماذج الكبيرة التي تحتوي على ملايين المعلمات أمرًا صعبًا بدون الأجهزة المناسبة. أحد الأخطاء الشائعة التي قد تواجهها عند نفاد ذاكرة GPU هو: + +``` +CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 11.17 GiB total capacity; 9.70 GiB already allocated; 179.81 MiB free; 9.85 GiB reserved in total by PyTorch) +``` + +فيما يلي بعض الحلول المحتملة التي يمكنك تجربتها لتقليل استخدام الذاكرة: + +- قلل من قيمة [`per_device_train_batch_size`](main_classes/trainer#transformers.TrainingArguments.per_device_train_batch_size) في [`TrainingArguments`]. + +- حاول استخدام [`gradient_accumulation_steps`](main_classes/trainer#transformers.TrainingArguments.gradient_accumulation_steps) في [`TrainingArguments`] لزيادة حجم الدُفعة بشكل فعال. + + +راجع دليل [الأداء](performance) لمزيد من التفاصيل حول تقنيات توفير الذاكرة. + + +## عدم القدرة على تحميل نموذج TensorFlow محفوظ + +تقوم طريقة TensorFlow [model.save](https://www.tensorflow.org/tutorials/keras/save_and_load#save_the_entire_model) بحفظ النموذج بالكامل - الهندسة المعمارية، الأوزان، تكوين التدريب - في ملف واحد. ومع ذلك، عند تحميل ملف النموذج مرة أخرى، قد تواجه خطأ لأن مكتبة 🤗 Transformers قد لا تقوم بتحميل جميع الكائنات المتعلقة بـ TensorFlow في ملف النموذج. لتجنب المشكلات المتعلقة بحفظ وتحميل نماذج TensorFlow، نوصي بما يلي: + +- احفظ أوزان النموذج كملف `h5` باستخدام [`model.save_weights`](https://www.tensorflow.org/tutorials/keras/save_and_load#save_the_entire_model) ثم أعد تحميل النموذج باستخدام [`~TFPreTrainedModel.from_pretrained`]: + +```python +>>> from transformers import TFPreTrainedModel +>>> from tensorflow import keras + +>>> model.save_weights("some_folder/tf_model.h5") +>>> model = TFPreTrainedModel.from_pretrained("some_folder") +``` + +- احفظ النموذج باستخدام [`~TFPretrainedModel.save_pretrained`] وقم بتحميله مرة أخرى باستخدام [`~TFPreTrainedModel.from_pretrained`]: + +```python +>>> from transformers import TFPreTrainedModel + +>>> model.save_pretrained("path_to/model") +>>> model = TFPreTrainedModel.from_pretrained("path_to/model") +``` + +## ImportError + +خطأ شائع آخر قد تواجهه، خاصة إذا كان نموذجًا تم إصداره حديثًا، هو `ImportError`: + +``` +ImportError: cannot import name 'ImageGPTImageProcessor' from 'transformers' (unknown location) +``` + +بالنسبة لأنواع الأخطاء هذه، تحقق من أن لديك أحدث إصدار من مكتبة Hugging Face Transformers مثبتًا للوصول إلى أحدث النماذج: + +```bash +pip install transformers --upgrade +``` + +## خطأ CUDA: تم تشغيل التأكيد على جانب الجهاز + +في بعض الأحيان، قد تواجه خطأ CUDA عامًا حول خطأ في كود الجهاز. + +``` +RuntimeError: CUDA error: device-side assert triggered +``` + +يجب عليك محاولة تشغيل الكود على وحدة المعالجة المركزية (CPU) أولاً للحصول على رسالة خطأ أكثر دقة. أضف متغير البيئة التالي في بداية كودك للتبديل إلى وحدة المعالجة المركزية: + +```python +>>> import os + +>>> os.environ["CUDA_VISIBLE_DEVICES"] = "" +``` + +الخيار الآخر هو الحصول على تتبع مكدس أفضل من GPU. أضف متغير البيئة التالي في بداية كودك للحصول على تتبع المكدس للإشارة إلى مصدر الخطأ: + +```python +>>> import os + +>>> os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +``` + +## إخراج غير صحيح عند عدم إخفاء رموز الحشو + +في بعض الحالات، قد يكون `hidden_state` غير صحيحة إذا تضمنت `input_ids` رموز حشو. ولإثبات ذلك، قم بتحميل نموذج ومجزىء لغوى. يمكنك الوصول إلى `pad_token_id` للنموذج لمعرفة قيمته. قد تكون `pad_token_id` `None` لبعض النماذج، ولكن يمكنك دائمًا تعيينها يدويًا. + +```python +>>> from transformers import AutoModelForSequenceClassification +>>> import torch + +>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-uncased") +>>> model.config.pad_token_id +0 +``` + +يوضح المثال التالي المُخرجات بدون إخفاء رموز الحشو: + +```python +>>> input_ids = torch.tensor([[7592, 2057, 2097, 2393, 9611, 2115], [7592, 0, 0, 0, 0, 0]]) +>>> output = model(input_ids) +>>> print(output.logits) +tensor([[ 0.0082, -0.2307], +[ 0.1317, -0.1683]], grad_fn=) +``` + +هنا المُخرجات الفعلية للتسلسل الثاني: + +```python +>>> input_ids = torch.tensor([[7592]]) +>>> output = model(input_ids) +>>> print(output.logits) +tensor([[-0.1008, -0.4061]], grad_fn=) +``` + +يجب عليك في معظم الوقت توفير `attention_mask` للنموذج لتجاهل رموز الحشو لتجنب هذا الخطأ الصامت. الآن يتطابق مُخرجات التسلسل الثاني مع مُخرجاته الفعلية: + + +بشكل افتراضي، ينشئ مجزىء النصوص `attention_mask` لك استنادًا إلى إعدادات المجزىء المحدد. + + +```python +>>> attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 0, 0, 0, 0, 0]]) +>>> output = model(input_ids, attention_mask=attention_mask) +>>> print(output.logits) +tensor([[ 0.0082, -0.2307], +[-0.1008, -0.4061]], grad_fn=) +``` + +لا ينشئ 🤗 Transformers تلقائيًا `attention_mask` لإخفاء رمز الحشو إذا تم توفيره لأن: + +- بعض النماذج ليس لها رمز حشو. + +- بالنسبة لبعض الاستخدامات، يريد المستخدمون أن ينتبه النموذج إلى رمز الحشو. +## ValueError: فئة التكوين غير المعترف بها XYZ لهذا النوع من AutoModel + +بشكل عام، نوصي باستخدام فئة [`AutoModel`] لتحميل النسخ المدربة مسبقًا من النماذج. يمكن لهذه الفئة أن تستنتج وتُحمل تلقائيًا البنية الصحيحة من نسخ معينة بناءً على التكوين. إذا رأيت هذا الخطأ `ValueError` عند تحميل نموذج من نسخة، فهذا يعني أن الفئة التلقائية (Auto) لم تتمكن من العثور على خريطة من التكوين في نقطة التفتيش المعطاة إلى نوع النموذج الذي تُحاول تحميله. وغالبًا ما يحدث هذا عندما لا تدعم نقطة التفتيش مهمة معينة. + +على سبيل المثال، سترى هذا الخطأ في المثال التالي لأنه لا يوجد GPT2 للإجابة على الأسئلة: + +```py +>>> from transformers import AutoProcessor, AutoModelForQuestionAnswering + +>>> processor = AutoProcessor.from_pretrained("openai-community/gpt2-medium") +>>> model = AutoModelForQuestionAnswering.from_pretrained("openai-community/gpt2-medium") +ValueError: Unrecognized configuration class for this kind of AutoModel: AutoModelForQuestionAnswering. +Model type should be one of AlbertConfig, BartConfig, BertConfig, BigBirdConfig, BigBirdPegasusConfig, BloomConfig, ... +``` From 33eef992503689ba1af98090e26d3e98865b2a9b Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:52:09 +0100 Subject: [PATCH 2/2] Agents: Small fixes in streaming to gradio + add tests (#34549) * Better support transformers.agents in gradio: small fixes and additional tests --- src/transformers/agents/agents.py | 3 +- src/transformers/agents/monitoring.py | 42 +++++++--- src/transformers/agents/python_interpreter.py | 36 ++++---- src/transformers/agents/tools.py | 16 ++-- tests/agents/test_monitoring.py | 82 +++++++++++++++++++ 5 files changed, 138 insertions(+), 41 deletions(-) create mode 100644 tests/agents/test_monitoring.py diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 73b7186d25a3c7..c461c50f29592c 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -1141,11 +1141,10 @@ def step(self): ) self.logger.warning("Print outputs:") self.logger.log(32, self.state["print_outputs"]) + observation = "Print outputs:\n" + self.state["print_outputs"] if result is not None: self.logger.warning("Last output from code snippet:") self.logger.log(32, str(result)) - observation = "Print outputs:\n" + self.state["print_outputs"] - if result is not None: observation += "Last output from code snippet:\n" + str(result)[:100000] current_step_logs["observation"] = observation except Exception as e: diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py index 8e28a72deb2a3e..755418d35a56a3 100644 --- a/src/transformers/agents/monitoring.py +++ b/src/transformers/agents/monitoring.py @@ -18,11 +18,19 @@ from .agents import ReactAgent -def pull_message(step_log: dict): +def pull_message(step_log: dict, test_mode: bool = True): try: from gradio import ChatMessage except ImportError: - raise ImportError("Gradio should be installed in order to launch a gradio demo.") + if test_mode: + + class ChatMessage: + def __init__(self, role, content, metadata=None): + self.role = role + self.content = content + self.metadata = metadata + else: + raise ImportError("Gradio should be installed in order to launch a gradio demo.") if step_log.get("rationale"): yield ChatMessage(role="assistant", content=step_log["rationale"]) @@ -46,30 +54,40 @@ def pull_message(step_log: dict): ) -def stream_to_gradio(agent: ReactAgent, task: str, **kwargs): +def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs): """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" try: from gradio import ChatMessage except ImportError: - raise ImportError("Gradio should be installed in order to launch a gradio demo.") + if test_mode: + + class ChatMessage: + def __init__(self, role, content, metadata=None): + self.role = role + self.content = content + self.metadata = metadata + else: + raise ImportError("Gradio should be installed in order to launch a gradio demo.") for step_log in agent.run(task, stream=True, **kwargs): if isinstance(step_log, dict): - for message in pull_message(step_log): + for message in pull_message(step_log, test_mode=test_mode): yield message - if isinstance(step_log, AgentText): - yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```") - elif isinstance(step_log, AgentImage): + final_answer = step_log # Last log is the run's final_answer + + if isinstance(final_answer, AgentText): + yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```") + elif isinstance(final_answer, AgentImage): yield ChatMessage( role="assistant", - content={"path": step_log.to_string(), "mime_type": "image/png"}, + content={"path": final_answer.to_string(), "mime_type": "image/png"}, ) - elif isinstance(step_log, AgentAudio): + elif isinstance(final_answer, AgentAudio): yield ChatMessage( role="assistant", - content={"path": step_log.to_string(), "mime_type": "audio/wav"}, + content={"path": final_answer.to_string(), "mime_type": "audio/wav"}, ) else: - yield ChatMessage(role="assistant", content=str(step_log)) + yield ChatMessage(role="assistant", content=str(final_answer)) diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index fbece2bebd350f..6e90f356cb928e 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -848,6 +848,13 @@ def evaluate_ast( raise InterpreterError(f"{expression.__class__.__name__} is not supported.") +def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str: + if len(print_outputs) < max_len_outputs: + return print_outputs + else: + return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n" + + def evaluate_python_code( code: str, static_tools: Optional[Dict[str, Callable]] = None, @@ -890,25 +897,12 @@ def evaluate_python_code( PRINT_OUTPUTS = "" global OPERATIONS_COUNT OPERATIONS_COUNT = 0 - for node in expression.body: - try: + try: + for node in expression.body: result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports) - except InterpreterError as e: - msg = "" - if len(PRINT_OUTPUTS) > 0: - if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT: - msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n" - else: - msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n" - msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" - raise InterpreterError(msg) - finally: - if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT: - state["print_outputs"] = PRINT_OUTPUTS - else: - state["print_outputs"] = ( - PRINT_OUTPUTS[:MAX_LEN_OUTPUT] - + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._" - ) - - return result + state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT) + return result + except InterpreterError as e: + msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT) + msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" + raise InterpreterError(msg) diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 994e1bdd817b0c..84bcf0fde61f18 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import ast import base64 import importlib import inspect @@ -141,15 +142,19 @@ def validate_arguments(self, do_validate_forward: bool = True): required_attributes = { "description": str, "name": str, - "inputs": Dict, + "inputs": dict, "output_type": str, } authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"] for attr, expected_type in required_attributes.items(): attr_value = getattr(self, attr, None) + if attr_value is None: + raise TypeError(f"You must set an attribute {attr}.") if not isinstance(attr_value, expected_type): - raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.") + raise TypeError( + f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." + ) for input_name, input_content in self.inputs.items(): assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary." assert ( @@ -248,7 +253,6 @@ def save(self, output_dir): def from_hub( cls, repo_id: str, - model_repo_id: Optional[str] = None, token: Optional[str] = None, **kwargs, ): @@ -266,9 +270,6 @@ def from_hub( Args: repo_id (`str`): The name of the repo on the Hub where your tool is defined. - model_repo_id (`str`, *optional*): - If your tool uses a model and you want to use a different model than the default, you can pass a second - repo ID or an endpoint url to this argument. token (`str`, *optional*): The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -354,6 +355,9 @@ def from_hub( if tool_class.output_type != custom_tool["output_type"]: tool_class.output_type = custom_tool["output_type"] + if not isinstance(tool_class.inputs, dict): + tool_class.inputs = ast.literal_eval(tool_class.inputs) + return tool_class(**kwargs) def push_to_hub( diff --git a/tests/agents/test_monitoring.py b/tests/agents/test_monitoring.py new file mode 100644 index 00000000000000..c43c9cb8bf86dd --- /dev/null +++ b/tests/agents/test_monitoring.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers.agents.agent_types import AgentImage +from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent +from transformers.agents.monitoring import stream_to_gradio + + +class MonitoringTester(unittest.TestCase): + def test_streaming_agent_text_output(self): + def dummy_llm_engine(prompt, **kwargs): + return """ +Code: +```` +final_answer('This is the final answer.') +```""" + + agent = ReactCodeAgent( + tools=[], + llm_engine=dummy_llm_engine, + max_iterations=1, + ) + + # Use stream_to_gradio to capture the output + outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True)) + + self.assertEqual(len(outputs), 3) + final_message = outputs[-1] + self.assertEqual(final_message.role, "assistant") + self.assertIn("This is the final answer.", final_message.content) + + def test_streaming_agent_image_output(self): + def dummy_llm_engine(prompt, **kwargs): + return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' + + agent = ReactJsonAgent( + tools=[], + llm_engine=dummy_llm_engine, + max_iterations=1, + ) + + # Use stream_to_gradio to capture the output + outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True)) + + self.assertEqual(len(outputs), 2) + final_message = outputs[-1] + self.assertEqual(final_message.role, "assistant") + self.assertIsInstance(final_message.content, dict) + self.assertEqual(final_message.content["path"], "path.png") + self.assertEqual(final_message.content["mime_type"], "image/png") + + def test_streaming_with_agent_error(self): + def dummy_llm_engine(prompt, **kwargs): + raise AgentError("Simulated agent error") + + agent = ReactCodeAgent( + tools=[], + llm_engine=dummy_llm_engine, + max_iterations=1, + ) + + # Use stream_to_gradio to capture the output + outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True)) + + self.assertEqual(len(outputs), 3) + final_message = outputs[-1] + self.assertEqual(final_message.role, "assistant") + self.assertIn("Simulated agent error", final_message.content)