В распоряжении SberDevices — огромные core-модели, построенные на всем известной архитектуре Transformer. Обучение такой модели может занимать очень много времени, а Inference — требует большого количества памяти компьютера. Расскажу, как при таких вводных мы обучаем core-модели и какие хаки мы используем, чтобы их облегчить и ускорить. Речь пойдёт о ML с позиции пайплайнов и продакшена виртуального ассистента Салют.
Привет, Хабр! Меня зовут Александр Абрамов и я ML Lead продукта в SberDevices. Эта статья — про обучение core-моделей retrieval-based диалоговых систем, поговорим про хинты для ускорения обучения и сходимости, также затрону тему общей схемы inference и оптимизации её компонентов.
Статья создана на основе моего доклада для конференции Highload.
Архитектура в голосовом ассистенте Салют решает множество задач: распознавание речи, предобработка текста, определение именованных сущностей, классификация намерения, аннотирование, исполнение навыка.
Всё это нужно, чтобы вести осознанный, интересный, эмпатичный, эмоциональный диалог. При этом на задержку даётся не больше одной секунды, то есть все эти задачи должны последовательно исполниться за 1 секунду и никак не более.
Core-модели в Салют строятся на больших трансформерных архитектурах.
Вначале у нас был BERT Large трансформер, занимающий в памяти GPU около 2 гигабайт. И это без учёта дообучения, то есть размер этот постоянно увеличивался. Ведь обучали мы его на огромных наборах данных — порядка миллионов различных диалогов, огромное количество специальной пост-тег разметки именованных сущностей, намерений, различной информации, определение тональности текста. Всё это помещалось для обучения в наш пайплайн с BERT Large.
Чтобы эта модель эффективно представляла наш текст в виде вектора, мы использовали Metric-Learning.
Чтобы задачи влезли в наш пайплайн обучения модели, мы использовали HOROVOD multi-GPU learning. HOROVOD — библиотека, позволяющая обучать модель как на TensorFlow, так и на PyTorch из коробки на нескольких GPU.
Мы используем его так:
Все данные шардируются, батчируются, раскидываются на несколько GPU-нод, которые параллельно учатся на одной и той же архитектуре. На старте каждая архитектура инициализируется рандомно, независимо друг от друга, и в конце каждой эпохи обучения веса “суммируются” за счёт обмена градиентами между GPU. По сути, мы используем DataParallel подход multiGPU-leaning. Так мы создаём улучшенную модель из весов моделей, которые мы раскидывали на GPU при помощи HOROVOD. В результате получили не простой BERT, а именно SBERT-multitask.
SBERT-multitask эффективно представляет поисковые запросы. А кроме этих запросов мы учим его вести диалоги, определять именованные сущности и так далее. Также есть несколько дополнительных тасок, поэтому нам требуется размещать в памяти GPU всю эту информацию.
Мы используем не только распараллеливания вычислений, но и способы облегчения модели при помощи дистилляции и сжатия. Это когда мы от сильной модели учителя передаём знания более слабой модели ученику, с минимальными потерями качества.
В данном случае мы используем в качестве одного из примеров LABSE-модель. Это тоже одна из последних известных моделей с эффективным представлением sentence-эмбеддингов. Используем дистилляцию эмбеддингов нашей SBERT модели, которую мы так масштабно обучили, к LABSE-модели. Это позволяет уменьшить размерность поискового индекса примерно в 3,5 раза. Делаем мы это за счёт пожимающего shift-слоя, который позволяет вместо размерности вектора 1024 иметь 768, а в конечном счете мы дошли до 300 от мерного представления. К тому же при дистилляции идёт наследование свойств — это добавляет 4% к качеству релевантности и 10% к скорости поиска. А поисковый индекс, используемый нами из векторов, с пониженной размерностью, весит в памяти всего лишь 400 MB.
Мы должны подумать не только о том, чтобы наша модель быстро работала, была сжатой, занимала меньше памяти, она ещё должна сохранять адекватность разговора. Для этого нам нужно продумать её устойчивость к атакам.
Чтобы модель была устойчива к атакам, мы сделали два фреймворка. Первый — simple_aug, который использует все возможные лингвистические замены наших символов во входной фразе на ошибочные символы, так у нас появляется match ошибок. При этом мы дополнительно используем RuTextFooler, который заменяет наши фразы не на уровне символов, а полностью — на похожие по написанию, но разные по смыслу. Таким образом, наша устойчивость увеличивается. В настоящее время оба этих модуля являются частями нашей библиотеки augmentex.
Такая устойчивость не только повышает качество модели, то есть её релевантность и разнообразие ответов, но и влияет на скорость сходимости. То есть вы получаете лучшую модель при обучении не за 10 эпох, а за 5 или за 7. Это на 5-10% увеличивает скорость обучения конечной модели.
Так как модель у нас большая и мы используем там много тасок, то в один батч для каждой таски помещается не больше 24-32 сэмплов. Если мы умножим это на 10 задач, то в самом большом размере будет 320 значений.
Если мы используем dgx v100 размером в 32 GB, то он едва ли это вместит, то есть батч будет маленький. Но мы знаем, что чем больше батч, тем лучше качество у трансформеров.
Исправить потенциальное падение оценки качества от размера батча можно с помощью Metric Learning подхода. У нас есть вопрос от пользователя, релевантный ответ и неправильный ответ на этот вопрос. Такими «тройками» мы кормим модель. Раньше использовали только уникальные тройки (вопрос + положительная пара или отрицательная пара), но затем начали использовать FullBatching.
Когда в батче все пары тематик уникальные, то для каждого сэмпла в батче все сэмплы, расположенные не в этой строке по индексу, тоже являются нерелевантными. Эту информацию тоже нужно использовать. Это значит, что у нас не 32 примеров батча, а n2 от 32, то есть на самом деле 322, хотя в памяти хранится только изначальное число.
Этот FullBatching мы используем при вычислениях внутри Loss, что позволяет учесть гораздо больше информации. Также это позволяет умещаться в памяти, не используя при этом гораздо больший объем данных, увеличить скорость сходимости и точность. Мы видим за раз больше негативных и позитивных примеров, не увеличивая количество видимых сэмплов. Смотрим не только на свою пару, а на все.
«Вишенка на торте» — подход, который позволяет нам очень быстро переобучать наши модели при поставке нового контента. Все предыдущие этапы занимают от двух недель до месяца на большом суперкомпьютере вроде Christofari. Но когда у нас появляются новые диалоги, то модель на них тоже надо дообучить. При этом мы не можем под каждую новую поставку ставить модель в очередь на месяц-два.
Решение есть: использовать самый популярный в последнее время P-Learning или адаптеры. То есть фризим основную часть модели, а feed-forward слои — переобучаем, т.е. только определённые маленькие модельки поверх трансформера, тем самым просто смещая распределение каких-то примеров или лейблов.
Это позволяет очень быстро за один спринт, релизный цикл, получить модель с обновленными данными по контенту.
Расскажу, как мы ускоряли и облегчали наши модели.
Когда мы говорим про скорость, у нас есть верхняя граница по latency. Если получается выиграть немного дополнительного времени, мы всегда задумываемся, как можно использовать больше фичей или модель побольше, потому что это напрямую влияет на онлайн метрики.
А ещё, когда мы говорим про скорость, подразумеваем скорость процессов — то есть скорость того, как быстро можно доставить изменения на продакшн, будь то багфиксы или новые фичи.
Важно, что все эти улучшения мы делаем при условии отсутствия деградации по определённым параметрам или принципам, которые мы для себя разработали:
метрики, за которыми мы следим;
удобство поддержания этого механизма;
возможность откатиться к прошлым версиям моделей;
быстрые итерации нововведений;
масштабируемость.
Разберём, как работают виртуальные ассистенты или conversation агенты.
Представьте, что у вас есть некая фраза от пользователя.
Всегда есть верхнеуровневый модуль Intent Recognizer, который роутирует трафик и определяет намерение пользователя: поставить таймер, включить музыку или пообщаться.
Дальше мы попадаем в модуль, про который и идёт речь в этой статье, — это общение на свободные темы. Здесь тоже есть внутренний intent recognizer, потому что трафик напрямую в «болталку» или чат-модель вы отправлять не будете. У вас есть некий набор интентов, которые, например, определяют характер персонажа. Это называется библия персонажа — как его зовут, как он выглядит, какие у него ответы. Или это могут быть какие-то кастомные сценарии или промо-акции, приуроченные к определённым датам. Поэтому и существует такое разделение.
Блок Annotators — это набор классификаторов поверх векторов, контекста, предложения или просто токенов, которые обогащают запрос. Эти предсказания можно использовать либо для сценарной логики, либо просто как фичи в ранжировщике, который подбирает наиболее интересную и релевантную реплику в диалоге.
У нас есть двухэтапный подход:
Есть базовая, сильная модель-векторизатор, которая умеет матчить похожие по смыслу, но разные по написанию фразы в одну точку в семантическом пространстве. Потом мы используем целую плеяду разных очень простых и лёгких моделей, которые могут решить конкретную задачу. Например, определить, является ли текущий контекст провокационным, или идентифицировать эмоцию/тематику конкретного запроса. Это намного быстрее, чем обучать 20 огромных моделей, потому что модели маленькие.
Такой подход позволяет легко добавлять новый аннотатор или задачи. Нужно просто дообучить эту простую сетку на небольшом объёме данных. Здесь же можно говорить и про замену моделей, то есть они разделены. Но появляется момент, за которым нужно следить, — это консистентность. Когда следующая модель напрямую зависит от векторов предыдущей, важно следить за тем, что их версии совпадают. Ведь если при этом вектора не сильно разъедутся на простых функциональных тестах, этого можно даже не заметить.
Когда есть такое выделение базового векторизатора, его можно обложить и холодным, и горячим кэшем. Скажу по секрету: если вы возьмёте кэш в 40 тысяч разных фраз, то покроете половину распределения даже в «болталке». Ведь когда мы в жизни общаемся, используем примерно эти же фразы: «привет», «как дела» и так далее. Поэтому для этого можно использовать кэши и не гонять лишний раз большую тяжёлую модель.
Обсудим блоки распознавания интентов и ChitChat модель.
Задачи «Болталки» и ведения простой беседы можно решать двумя способами:
Использовать генеративные сети, которые авторегрессионно, токен за токеном восстанавливают ответ, к примеру, GPT. Мы так и делаем — используем большую сеть.
Решать задачу информационного поиска. Это когда есть огромная база из заготовленных реплик, нужно сметчить и достать наиболее подходящие под данный контекст реплики из этой базы, а потом их отранжировать.
Когда мы говорим про подход, основанный на поиске и распознавании интентов, задача сводится к векторизации текущего контекста. Это нужно, чтобы посмотреть, какие кандидаты лучше всего подходят и решается с помощью быстрого поиска ближайших соседей.
Мы с самого начала начали использовать библиотеку для хранения и поиска информации в векторном представлении FAISS. Нашим требованием он в целом удовлетворял, поэтому и дальше использовали его.
Кусок на изображении выше — из ReadMe. То есть можно использовать буквально три строки.
Но есть нюанс. Дело в том, что быстрый поиск ближайших соседей мы используем для разных задач. Получается, что есть разные индексы, которые можно по-разному конфигурировать. Так как в нашем случае у нас нет большой базы в миллиард условно разных сэмплов, мы не используем Product Quantization, чтобы сжимать данные. Нам хватает перевести всё во float16.
Мы видим определённый трейд-офф между точностью, воспроизводимостью и скоростью, которая нас будет устраивать. Поэтому в случае распознавания интентов и «болталки», основанном на информационном поиске на этапе генерации кандидатов, мы используем связку нашего SBERT и FAISS. Здесь же, естественно, у нас заготовлены все кэши для базы, мы подобрали настройки.
Упомяну момент, которого в ReadMe нет — то, как вы будете загружать в память этот самый индекс. Можно либо на CPU, либо на GPU, при этом часть индексов будет на GPU. Можно сохранить весь большой индекс полностью, засериализовать и загружать. А можно создать индекс и загружать шардами. Второй вариант позволит избежать пики по памяти во время старта пода. Это может быть достаточно критично, потому что если вывалиться из памяти GPU, можно просто завязнуть в этом бесконечном рестарте.
Консистентность больше относится в целом ко всему проекту, а не только к двухэтапной системе аннотаторов.
Дело в том, что с самого начала прототипы делали просто: есть Docker-контейнер для проекта и модели, в котором мы хранили код. Естественно, это не тот путь, которому нужно следовать. Со временем мы разделили код и определённые конфигурации. Мы называем их «статики» — это наши модели и кэши.
Это нужно, потому что в большинстве компаний релизные процессы кода, статиков и весов модели отличаются по скорости. Поэтому в нашем случае намного удобнее в Docker держать исключительно код и библиотеки, которые исполняются. Тогда и Docker будет легче, а значит будет быстрее загружатся в registry и стартовать. А сами модели и кэши храним отдельно, и во время старта пода подтягиваем из бакета. Это даёт удобство: можно раздельно обновлять модели и быть уверенными, что всё в целом будет хорошо.
Но здесь нужны проверки:
Проверка совпадения версии между кодом и статиками.
Совпадение векторизатора и кэшей, которые к нему относятся.
Поэтому у нас есть множество разных проверок, которые оценивают консистентность между векторами кэшей и векторизатора.
На изображении выше указаны пункты, которые принесли нам больше всего пользы с точки зрения обучения и inference:
Мультитаск, когда модель базово обучается на разные синергичные задачи — те, которые не сильно разделяют эмбеддинги между собой. Так можно получить более сильную модель. Мульти-GPU на HOROVOD, а батч — побольше. Так мы получаем более сильную базовую модель векторизатора.
Поиск на inference. Есть поисковый запрос и база, которая помогает уменьшить размерность. То есть когда у нас в базе лежит не 1000 float32 на каждое предложение, а всего 300, это здорово экономит память.
Adversarial атаки. Только начав, мы просто хотели повысить консистентность к опечаткам и к ошибкам ASR. Но эта штука оказалась действительно классной. Это очень дешёвое действие, которое позволяет и улучшить сходимость, и подтянуть метрики, и сделать модель более устойчивой.
Дистилляция.
Для inference, рекомендуем:
- Использовать кэши, где это возможно, в рамках разумного.
- Использовать библиотеку для быстрого поиска ближайших соседей.
- Наладить CI/CD процессы.
- Подумать про двухэтапный подход, когда нужно проаннотировать или проставить метки для текущего контекста. То есть сделать одну базовую модель-векторизатор и множество маленьких, которые решают каждая свою задачу. Причём их можно тренировать прямо на этапе мультитаска, а потом отдельно тюнить.
- Переводить модели в специальные форматы. Мы все модели переводим в специальные форматы для inference, благо, они есть как у TensorFlow, так и у PyTorch. Это позволяет не делать ненужных вычислений во время inference и ужимать сами модели.