Обучение моделей TensorFlow с использованием GradientTape

Изучение TensorFlow моделей с использованием GradientTape

Фото от Sivani Bandaru на Unsplash

Использование GradientTape для обновления весов

TensorFlow, без сомнения, является самой популярной библиотекой для глубокого обучения. Раньше я написал столько учебников по TensorFlow, и продолжаю их писать. TensorFlow очень хорошо организован и прост в использовании, вам не нужно слишком беспокоиться о разработке и обучении моделей. Пакет самостоятельно берет на себя большую часть работы. Вероятно, это и является причиной его популярности в индустрии. Но в то же время, иногда хотелось бы иметь контроль над внутренними функциональными возможностями. Это дает вам много возможностей для экспериментов с моделями. Если вы ищете работу, дополнительные знания могут дать вам преимущество.

Ранее я написал статью о том, как разрабатывать пользовательские функции активации, слои и функции потерь. В этой статье мы увидим, как вы можете обучать модель вручную и самостоятельно обновлять веса. Но не волнуйтесь. Вам не придется вспоминать дифференциальное исчисление снова. У нас есть метод GradientTape(), доступный в самом TensorFlow, чтобы позаботиться о этой части.

Если GradientTape() для вас совершенно новый, не стесняйтесь проверить эти упражнения по GradientTape(), которые показывают, как работает GradientTape(): Введение в GradientTape в TensorFlow — Regenerative (regenerativetoday.com)

Подготовка данных

В этой статье мы работаем с простым алгоритмом классификации в TensorFlow с использованием GradientTape(). Пожалуйста, загрузите набор данных по этой ссылке:

Датасет прогнозирования сердечной недостаточности (kaggle.com)

Этот набор данных имеет лицензию на открытую базу данных.

Вот необходимые импорты:

import tensorflow as tffrom tensorflow.keras.models import Modelfrom tensorflow.keras.layers import Dense, Inputimport numpy as npimport matplotlib.pyplot as pltimport matplotlib.ticker as mtickerimport pandas as pdfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import confusion_matriximport itertoolsfrom tqdm import tqdmimport tensorflow_datasets as tfds

Создание DataFrame с набором данных:

import pandas as pddf = pd.read_csv('heart.csv')df