Что, если мы могли бы легко объяснить чрезмерно сложные модели?

Как легко разъяснить чрезмерно сложные модели?

Создание контрфактических объяснений стало намного проще с помощью CFNOW, но что такое контрфактические объяснения и как их можно использовать?

Изображение, сгенерированное с помощью модели распространения иллюзии с текстом CFNOW в качестве иллюзии (попробуйте прищурить глаза и посмотреть с определенного расстояния) | Изображение автора, сделанное с использованием модели Stable Diffusion (лицензия)

Эта статья основана на следующей статье: https://www.sciencedirect.com/science/article/abs/pii/S0377221723006598

Вот адрес репозитория CFNOW: https://github.com/rmazzine/CFNOW

Если вы читаете это, то, возможно, вы знаете, насколько важным становится искусственный интеллект (ИИ) в нашем мире сегодня. Однако важно отметить, что казалось бы эффективные, новаторские подходы машинного обучения, совместно с их всеобъемлющей популярностью, могут привести к непредвиденным или нежелательным последствиям.

Именно поэтому объяснение искусственного интеллекта (XAI) является важной составной частью обеспечения этического и ответственного развития ИИ. В этой области показано, что объяснение моделей, которые состоят из миллионов или даже миллиардов параметров, не является тривиальным вопросом. Ответ на это многообразен, так как существует множество методов, раскрывающих различные аспекты модели, причем LIME [1] и SHAP [2] являются популярными примерами.

Однако сложность объяснений, генерируемых этими методами, может привести к сложным диаграммам или анализу, что потенциально может привести к неправильному их толкованию теми, кто не является осведомленными экспертами. Одним из возможных способов обойти эту сложность является простой и естественный метод объяснения вещей, называемый контрфактическими объяснениями [3].

Контрфактические объяснения используют естественное человеческое поведение для объяснения вещей – создание “альтернативных миров”, где изменение нескольких параметров может изменить результат. Это обычная практика, вы, вероятно, уже делали что-то подобное – “если бы я проснулся немного раньше, я бы не пропустил автобус”, такой тип объяснения наглядно подчеркивает основные причины результата простым способом.

Углубляясь в детали, контрфактические объяснения превосходят простое объяснение; они могут служить руководством для изменений, помогать в отладке аномального поведения и проверять, могут ли некоторые функции потенциально изменить предсказания (не оказывая существенного влияния на оценку). Этот многофункциональный характер подчеркивает важность объяснения ваших предсказаний. Это не только вопрос ответственного использования искусственного интеллекта; это также путь к улучшению моделей и использованию их за рамками предсказаний. Замечательным свойством контрфактических объяснений является их связь с принятием решений, что позволяет им прямо соответствовать изменению предсказания [6], в отличие от LIME и SHAP, которые больше подходят для объяснения оценок.

Учитывая явные преимущества, можно задаться вопросом, почему контрфакты не пользуются большей популярностью. Это вполне логичный вопрос! Основные преграды перед всеобщим принятием контрфактических объяснений тройные [4, 5]: (1) отсутствие пользовательских и совместимых алгоритмов генерации контрфактов, (2) неэффективность алгоритма генерации контрфактов, (3) и отсутствие комплексного визуального представления.

Но у меня есть хорошие новости для вас! Новый пакет, CFNOW (CounterFactuals NOW или CounterFactual Nearest Optimal Wololo), готов принять эти вызовы. CFNOW – это универсальный пакет на языке Python, способный генерировать несколько контрфактов для различных типов данных, таких как таблицы, изображения и текстовые (вложенные) данные. Он использует модельно-независимый подход и требует только минимальных данных – (1) фактическая точка (точка, которую нужно объяснить) и (2) функция предсказания.

Кроме того, CFNOW структурирован таким образом, чтобы позволять разработку и интеграцию новых стратегий поиска и настройки контрфактов на основе пользовательской логики. Он также включает CounterPlots, новую стратегию для визуального представления контрфактических объяснений.

Центральным элементом CFNOW является фреймворк, который преобразует данные в одну управляемую CF генератором структуру. В результате двухэтапного процесса находится и оптимизируется найденная контрфактуальная структура. Для предотвращения попадания в локальные минимумы пакет реализует метод Tabu Search, являющийся матурэвристическим методом, позволяющим исследовать новые области, где функция цели может быть лучше оптимизирована.

В последующих разделах этого текста будет продемонстрировано, как CFNOW можно эффективно использовать для генерации объяснений для классификаторов на табличных, изображениях и текстовых (встраивание) данных.

Табличные классификаторы

Здесь мы покажем типичные данные, у которых есть несколько типов данных на табличной основе. В приведенном ниже примере я буду использовать набор данных, в котором есть числовые непрерывные данные, категориальные бинарные и категориальные однофакторные закодированные данные для полного демонстрирования функционала CFNOW.

Прежде всего, вам необходимо установить пакет CFNOW, требуется версия Python выше 3.8:

    pip install cfnow

(здесь полный код для этого примера: https://colab.research.google.com/drive/1GUsVfcM3I6SpYCmsBAsKMsjVdm-a6iY6?usp=sharing)

В этой первой части мы создадим классификатор с помощью Adult Dataset. Здесь нет ничего нового:

import warningsimport pandas as pdfrom sklearn.model_selection import train_test_splitfrom sklearn.ensemble import RandomForestClassifierfrom sklearn.metrics import accuracy_scorewarnings.filterwarnings("ignore", message="X does not have valid feature names, but RandomForestClassifier was fitted with feature names")

Мы импортируем основные пакеты для создания модели классификации и также отключаем предупреждения, связанные с деланием прогнозов без имен столбцов.

Затем мы переходим к написанию классификатора, где класс 1 представляет доход ниже или равный 50 тыс. (<=50K), а класс 0 представляет высокий доход.

# Создаем классификаторimport warningsimport pandas as pdfrom sklearn.model_selection import train_test_splitfrom sklearn.ensemble import RandomForestClassifierfrom sklearn.metrics import accuracy_scorewarnings.filterwarnings("ignore", message="X does not have valid feature names, but RandomForestClassifier was fitted with feature names")# Загружаем набор данныхdataset_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"column_names = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',                'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',                'hours-per-week', 'native-country', 'income']data = pd.read_csv(dataset_url, names=column_names, na_values=" ?", skipinitialspace=True)# Удаляем строки с отсутствующими значениямиdata = data.dropna()# Определяем категориальные признаки, которые не являются бинарнымиnon_binary_categoricals = [column for column in data.select_dtypes(include=['object']).columns                            if len(data[column].unique()) > 2]binary_categoricals = [column for column in data.select_dtypes(include=['object']).columns                        if len(data[column].unique()) == 2]cols_numericals = [column for column in data.select_dtypes(include=['int64']).columns]# Применяем one-hot encoding к не бинарным категориальным признакамdata = pd.get_dummies(data, columns=non_binary_categoricals)# Преобразуем бинарные категориальные признаки в числовые# Это также приведет к бинаризации целевой переменной (доход)for bc in binary_categoricals:    data[bc] = data[bc].apply(lambda x: 1 if x == data[bc].unique()[0] else 0)# Делим набор данных на признаки и целевую переменнуюX = data.drop('income', axis=1)y = data['income']# Делим набор данных на обучающую и тестовую выборкиX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Инициализируем RandomForestClassifierclf = RandomForestClassifier(random_state=42)# Обучаем классификаторclf.fit(X_train, y_train)# Делаем прогнозы для тестовой выборкиy_pred = clf.predict(X_test)# Оцениваем классификаторaccuracy = accuracy_score(y_test, y_pred)print("Точность:", accuracy)

С помощью приведенного выше кода мы создаем набор данных, предварительно обрабатываем его, создаем модель классификации и делаем прогноз и оценку на тестовом наборе данных.

Теперь давайте взять одну точку (первую из тестового набора данных) и проверим ее прогноз:

clf.predict([X_test.iloc[0]])# Результат: 0 -> Высокий доход

Теперь пришло время использовать CFNOW, чтобы определить, как мы можем изменить этот прогноз, минимально модифицируя характеристики:

from cfnow import find_tabular# Затем мы используем CFNOW для генерации минимальной модификации, чтобы изменить классификациюcf_res = find_tabular(    factual=X_test.iloc[0],    feat_types={c: 'num' if c in cols_numericals else 'cat' for c in X.columns},    has_ohe=True,    model_predict_proba=clf.predict_proba,    limit_seconds=60)

В приведенном выше коде мы:

  • factualДобавляем фактический экземпляр как pd.Series
  • feat_typesУказываем типы характеристик (“num” для числовых непрерывных и “cat” для категориальных)
  • has_oheУказываем, что у нас есть OHE-характеристики (он автоматически обнаруживает OHE-характеристики, объединяя те, у которых есть одинаковый префикс, за которым следует подчеркивание, например, country_brazil, country_usa, country_ireland).
  • model_predict_probaВключает функцию прогнозирования
  • limit_secondsОпределяет общий порог времени для выполнения, это важно, потому что этап настройки может продолжаться неопределенно долго (по умолчанию 120 секунд)

Затем, спустя некоторое время, мы можем оценить класс лучшего контрфактного значения (первый индекс cf_res.cfs)

clf.predict([cf_obj.cfs[0]])# Результат: 1 -> Низкий доход

И вот некоторые отличия с CFNOW, так как он также интегрирует CounterPlots, мы можем отобразить их графики и получить более проницательную информацию, например, такую:

Диаграмма CounterShapley для нашего CF | Изображение автора

Диаграмма CounterShapley ниже показывает относительную важность каждой характеристики для генерации контрфактного прогноза. Здесь мы получаем несколько интересных выводов, показывающих, что семейное положение (если объединено) составляет более 50% вклада в класс CF.

Диаграмма Greedy для нашего CF | Изображение автора

Диаграмма Greedy показывает нечто очень похожее на CounterShapley, основное отличие здесь – последовательность изменений. В то время как CounterShapley не учитывает какую-либо конкретную последовательность (вычисляя вклады с использованием значений Шапли), диаграмма Greedy использует жадную стратегию для изменения фактического экземпляра, меняя каждый шаг, изменяющий характеристику, которая наиболее вносит вклад в класс CF. Это может быть полезным в ситуациях, когда имеется какое-то руководство, основанное на жадном подходе (каждый шаг выбирая лучший подход для достижения цели).

Диаграмма Constellation для нашего CF | Изображение автора

Наконец, у нас есть самый сложный анализ – диаграмма Constellation. Несмотря на свою пугающий вид, по сути, ее довольно просто интерпретировать. Каждая большая красная точка представляет собой изменение одной единственной характеристики (относительно метки), а меньшие точки представляют собой комбинацию двух или более характеристик. Наконец, большая синяя точка представляет собой оценку CF. Здесь мы видим, что единственный способ получить CF с этими характеристиками – это изменить все из них на их соответствующие значения (т.е. нет набора, который генерирует CF). Мы также можем глубже исследовать взаимосвязи между характеристиками и, возможно, найти интересные закономерности.

В данном конкретном случае было интересно наблюдать, что прогноз о высоком доходе изменился бы, если бы у человека был женский пол, он был в разводе и имел собственного ребенка. Этот контрфакт может привести к дальнейшим обсуждениям экономических последствий для разных социальных групп.

Классификаторы изображений

Как уже упоминалось ранее, CFNOW может работать с различными типами данных, поэтому он также может генерировать контрфактические ситуации для изображений. Однако, что означает иметь контрфактный результат для набора данных изображений?

Ответ может различаться, потому что существует несколько способов генерации контрфактических ситуаций. Это может быть замена отдельных пикселей случайным шумом (метод, используемый в атаках адверсариального типа) или что-то более сложное, включающее продвинутые методы сегментации.

CFNOW использует метод сегментации, называемый quickshift, который является надежным и быстрым методом обнаружения “семантических” сегментов. Однако, возможно, интегрировать (и я приглашаю вас это сделать) другие методы сегментации.

Обнаружение сегментов само по себе недостаточно для генерации контрфактических объяснений. Нам также нужно изменить сегменты, заменив их на модифицированные версии. Для этой модификации CFNOW предлагает четыре опции, определенные в параметре replace_mode, где мы можем использовать: (по умолчанию) blur – добавление размытия к замененным сегментам, mean – замена сегментов на средний цвет, random – замена случайным шумом, и inpaint – восстановление изображения на основе соседних пикселей.

Если вы хотите получить весь код, вы можете найти его здесь: https://colab.research.google.com/drive/1M6bEP4x7ilSdh01Gs8xzgMMX7Uuum5jZ?usp=sharing

Ниже я покажу реализацию кода CFNOW для этого типа данных:

Сначала, давайте снова установим пакет CFNOW, если вы еще не сделали этого.

pip install cfnow

Теперь давайте добавим некоторые дополнительные пакеты, чтобы загрузить предобученную модель:

pip install torch torchvision Pillow requests

Затем давайте загрузим данные, загрузим предобученную модель и создадим функцию предсказания, совместимую с форматом данных, которые CFNOW должен получить:

import requestsimport numpy as npfrom PIL import Imagefrom torchvision import models, transformsimport torch# Загрузить предобученную модель ResNetmodel = models.resnet50(pretrained=True)model.eval()# Определить преобразование изображенияtransform = transforms.Compose([    transforms.Resize(256),    transforms.CenterCrop(224),    transforms.ToTensor(),    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])# Получить изображение из Интернетаimage_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/41/Sunflower_from_Silesia2.jpg/320px-Sunflower_from_Silesia2.jpg"response = requests.get(image_url, stream=True)image = np.array(Image.open(response.raw))def predict(images):    if len(np.shape(images)) == 4:        # Преобразование списка массивов numpy в пакет тензоров        input_images = torch.stack([transform(Image.fromarray(image.astype('uint8'))) for image in images])    elif len(np.shape(images)) == 3:        input_images = transform(Image.fromarray(images.astype('uint8')))    else:        raise ValueError("Входные данные должны быть списком изображений или одиночным изображением.")        # Проверить доступность GPU, и если его нет, использовать ЦПУ    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    input_images = input_images.to(device)    model.to(device)        # Выполнить вывод    with torch.no_grad():        outputs = model(input_images)        # Вернуть массив оценок предсказания для каждого изображения    return torch.asarray(outputs).cpu().numpy()LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"def predict_label(outputs):    # Загрузить метки, используемые предобученной моделью    labels = requests.get(LABELS_URL).json()        # Получить предсказанные метки    predicted_idxs = [np.argmax(od) for od in outputs]    predicted_labels = [labels[idx.item()] for idx in predicted_idxs]        return predicted_labels# Проверить предсказание для изображениепредicted_label = predict([np.array(image)])print("Predicted labels:", predict_label(predicted_label))

Большая часть работы по коду связана с созданием модели, получением данных и их настройкой, потому что для генерации контрфактических ситуаций с CFNOW нам просто необходимо:

от cfnow импортировать find_imagecf_img = find_image(img=image, model_predict=predict)cf_img_hl = cf_img.cfs[0]print("Предсказанные метки:", predict_label(predict([cf_img_hl])))# Показать изображение CFImage.fromarray(cf_img_hl.astype('uint8'))

В приведенном выше примере мы использовали все параметры по умолчанию, поэтому мы использовали алгоритм quickshift для сегментации изображения и замены сегментов размытыми изображениями. В результате у нас есть следующее фактическое предсказание:

Фактическое изображение классифицировано как “ромашка” | Название изображения: Подсолнух (Helianthus L). Слонечник от Пуделек (редактирование Yzmo и Vassil) из Викимедиа на основе GNU Free Documentation License, версия 1.2

Следующее:

Изображение CF классифицировано как “пчела” | Название изображения: Подсолнух (Helianthus L). Слонечник от Пуделек (редактирование Yzmo и Vassil) из Викимедиа на основе GNU Free Documentation License, версия 1.2

Так что же мы можем извлечь из этого анализа? Фактически, контрфактные изображения могут быть крайне полезными инструментами для определения того, как модель делает классификацию. Это может применяться в случаях, когда: (1) мы хотим проверить, почему модель сделала правильные классификации – убедиться, что она использует правильные особенности изображения: в этом случае, хотя она неправильно классифицировала подсолнух как ромашку, мы видим, что размытие цветка (а не фона) приводит к изменению предсказания. Кроме того, (2) это может помочь диагностировать ошибки в классификации изображений, что может привести к лучшим пониманию обработки изображений и/или сбора данных.

Текстовые классификаторы

Наконец, у нас есть текстовые классификаторы на основе встроенных представлений. Хотя простые текстовые классификаторы (которые используют структуру данных, похожую на таблицу) могут использовать генератор контрфактных таблиц, для текстовых классификаторов на основе встроенных представлений это не так очевидно.

Обоснование заключается в том, что встроенные представления имеют переменное количество входных данных и слов, которые могут существенно влиять на оценку предсказания и классификацию.

CFNOW решает эту проблему с помощью двух стратегий: (1) удаление доказательств или (2) добавление антонимов. Первая стратегия прямолинейна: чтобы измерить влияние каждого слова на текст, мы просто удаляем их и смотрим, какие нужно удалить, чтобы изменить классификацию. В то время как при добавлении антонимов мы, возможно, сможем сохранить семантическую структуру (поскольку удаление слова может серьезно ее повредить).

Ниже приведен код, показывающий, как использовать CFNOW в этом контексте.

Если вам нужен полный код, вы можете проверить его здесь: https://colab.research.google.com/drive/1ZMbqJmJoBukqRJGqhUaPjFFRpWlujpsi?usp=sharing

Сначала установите пакет CFNOW:

pip install cfnow

Затем установите необходимые пакеты для текстовой классификации:

pip install transformers

Затем, как и в предыдущих разделах, сначала мы построим классификатор:

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import pipeline
import numpy as np
# Загрузка предобученной модели и токенизатора для анализа тональности
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)
# Определение конвейера анализа тональности
sentiment_analysis = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
  
# Определение простого набора данных
text_factual = "Мне понравился этот фильм, потому что он был смешным, но моим друзьям он не понравился, потому что был слишком длинным и скучным."
result = sentiment_analysis(text_factual)
print(f"{text_factual}: {result[0]['label']} (уверенность: {result[0]['score']:.2f})")
  
def pred_score_text(list_text):
    if type(list_text) == str:
        sa_pred = sentiment_analysis(list_text)[0]
        sa_score = sa_pred['score']
        sa_label = sa_pred['label']
        return sa_score if sa_label == "POSITIVE" else 1.0 - sa_score
    return np.array([sa["score"] if sa["label"] == "POSITIVE" else 1.0 - sa["score"] for sa in sentiment_analysis(list_text)])

Как видно из приведенного выше кода, наш информативный текст имеет ОТРИЦАТЕЛЬНУЮ тональность с высокой уверенностью (≥0,9), теперь попробуем сгенерировать контрфактуальный вариант для этого текста:

from cfnow import find_text
cf_text = find_text(text_input=text_factual, textual_classifier=pred_score_text)
result_cf = sentiment_analysis(cf_text.cfs[0])
print(f"CF: {cf_text.cfs[0]}: {result_cf[0]['label']} (уверенность: {result_cf[0]['score']:.2f})")

С помощью приведенного выше кода, изменив всего одно слово (но), классификация изменилась с ОТРИЦАТЕЛЬНОЙ на ПОЛОЖИТЕЛЬНУЮ с высокой уверенностью. Это демонстрирует, насколько полезными могут быть контрфактуальные объяснения, поскольку эти незначительные изменения могут оказывать влияние на понимание того, как модель предсказывает предложения и/или помогать устранять нежелательные поведенческие модели.

Вывод

Это было (относительно) краткое введение в CFNOW и контрфактуальные объяснения. В литературе существует обширное (и растущее) количество работ, посвященных контрфактуальным объяснениям, которые стоит изучить, если вы хотите более подробно изучить эту тему. Особенно стоит обратить внимание на эту фундаментальную статью [3], написанную моим научным руководителем, профессором Дэвидом Мартенсом, – это отличное введение в контрфактуальные объяснения. Кроме того, есть хорошие обзоры, такие как обзор, написанный Вермой и др. [7]. В заключение, контрфактуальные объяснения – это простой и удобный способ объяснения сложных алгоритмов машинного обучения, и они могут выполнять намного больше, чем просто объяснения, если применяются правильно. CFNOW может предоставить простой, быстрый и гибкий способ генерации контрфактуальных объяснений, позволяя практикующим специалистам не только объяснять, но и использовать максимум потенциала своих данных и моделей.

Ссылки:

[1] — https://github.com/marcotcr/lime[2] — https://github.com/shap/shap[3] — https://www.jstor.org/stable/26554869[4] — https://www.mdpi.com/2076-3417/11/16/7274[5] — https://arxiv.org/pdf/2306.06506.pdf[6] — https://arxiv.org/abs/2001.07417[7] — https://arxiv.org/abs/2010.10596