{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "mjAScbd2vl9P"
},
"source": [
"# OCR model for reading Captchas\n",
"\n",
"**Author:** [A_K_Nain](https://twitter.com/A_K_Nain)<br>\n",
"**Date created:** 2020/06/14<br>\n",
"**Last modified:** 2020/06/26<br>\n",
"**Description:** How to implement an OCR model using CNNs, RNNs and CTC loss."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wWvlZPBJvl9U"
},
"source": [
"## Introduction\n",
"\n",
"This example demonstrates a simple OCR model built with the Functional API. Apart from\n",
"combining CNN and RNN, it also illustrates how you can instantiate a new layer\n",
"and use it as an \"Endpoint layer\" for implementing CTC loss. For a detailed\n",
"guide to layer subclassing, please check out\n",
"[this page](https://keras.io/guides/making_new_layers_and_models_via_subclassing/)\n",
"in the developer guides."
]
},
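{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the real model, here is a minimal, self-contained sketch of the \"Endpoint layer\" pattern: a layer that receives the targets as an extra input, registers a loss with `add_loss()`, and returns the predictions unchanged. The toy layer below uses a mean absolute error purely as an illustration; the actual model further down uses CTC loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of the endpoint-layer pattern (not the CTC layer used later):\n",
"# the loss is attached inside call() via add_loss(), so the model can be\n",
"# compiled without passing a loss function.\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers\n",
"\n",
"class ToyEndpointLayer(layers.Layer):\n",
" def call(self, y_true, y_pred):\n",
" self.add_loss(tf.reduce_mean(tf.abs(y_true - y_pred)))\n",
" return y_pred"
]
},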
{
"cell_type": "markdown",
"metadata": {
"id": "Yq0Pe4Zuvl9U"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5q-xCl8Qvl9V"
},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"\n",
"from pathlib import Path\n",
"from collections import Counter\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "KIc-3qB0L5OE",
"outputId": "5162487a-c946-4569-aa2a-b6dd3944e85a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 2
}
],
"source": [
"tf.executing_eagerly()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sSm7N--8vl9W"
},
"source": [
"## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)\n",
"Let's download the data."
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"id": "g3EVJfHBvl9X",
"outputId": "63444ebd-1935-47ff-c031-0060c94fe3fc",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of images found: 10610\n",
"Number of labels found: 10610\n",
"Number of unique characters: 22\n",
"Characters present: [' ', '0', '2', '4', '5', '8', 'A', 'D', 'G', 'H', 'J', 'K', 'M', 'N', 'P', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']\n"
]
}
],
"source": [
"substitutions = {\n",
" 'B': '8',\n",
" 'F': 'P',\n",
" 'U': 'V',\n",
" '6': 'G',\n",
" 'Z': '2',\n",
" 'O': '0'\n",
"}\n",
"\n",
"def apply_substitutions(input_string):\n",
" output_string = \"\"\n",
" for char in input_string:\n",
" if char in substitutions:\n",
" output_string += substitutions[char]\n",
" else:\n",
" output_string += char\n",
"\n",
" return output_string\n",
"\n",
"data_dir = Path(\"./images_10k/\")\n",
"\n",
"# Get list of all the images\n",
"images = sorted(list(map(str, list(data_dir.glob(\"*.png\")))))\n",
"labels = [apply_substitutions(img.split(os.path.sep)[-1].split(\".png\")[0]) for img in images]\n",
"\n",
"# Maximum length of any captcha in the dataset\n",
"max_length = max([len(label) for label in labels])\n",
"labels = [x + ' ' * (max_length - len(x)) for x in labels]\n",
"\n",
"characters = set(char for label in labels for char in label)\n",
"characters = sorted(list(characters))\n",
"\n",
"print(\"Number of images found: \", len(images))\n",
"print(\"Number of labels found: \", len(labels))\n",
"print(\"Number of unique characters: \", len(characters))\n",
"print(\"Characters present: \", characters)\n",
"\n",
"# Batch size for training and validation\n",
"batch_size = 16\n",
"\n",
"# Desired image dimensions\n",
"img_width = 300\n",
"img_height = 80\n",
"\n",
"# Factor by which the image is going to be downsampled\n",
"# by the convolutional blocks. We will be using two\n",
"# convolution blocks and each block will have\n",
"# a pooling layer which downsample the features by a factor of 2.\n",
"# Hence total downsampling factor would be 4.\n",
"downsample_factor = 4\n",
"\n",
"\n"
]
},
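{
"cell_type": "markdown",
"metadata": {},
"source": [
"The downsampling factor above fixes how many time steps the recurrent part of the model will see. As a quick sanity check (a small sketch using the variables defined above), the width after the two pooling layers must give at least as many time steps as the longest label, otherwise CTC cannot align the predictions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch: two 2x2 max-pool layers divide the image width by\n",
"# `downsample_factor`; CTC needs at least `max_length` time steps.\n",
"time_steps = img_width // downsample_factor\n",
"print(\"RNN time steps:\", time_steps)  # 300 // 4 = 75\n",
"print(\"Longest label:\", max_length)\n",
"assert time_steps >= max_length"
]
},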
{
"cell_type": "markdown",
"metadata": {
"id": "gqn-NjRovl9Y"
},
"source": [
"## Preprocessing"
]
},
{
"cell_type": "code",
"source": [
"!rm -rf sdir"
],
"metadata": {
"id": "8bVogUbzY6Fi"
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"id": "MjQltH0Mvl9Y"
},
"outputs": [],
"source": [
"from skimage.morphology import opening, square, label\n",
"from skimage.measure import regionprops\n",
"from skimage.io import imread, imsave\n",
"from skimage import img_as_ubyte\n",
"\n",
"# Mapping characters to integers\n",
"char_to_num = layers.StringLookup(\n",
" vocabulary=list(characters), mask_token=None,\n",
")\n",
"\n",
"# Mapping integers back to original characters\n",
"num_to_char = layers.StringLookup(\n",
" vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True\n",
")\n",
"\n",
"def filter_image(img, kernel_size=3, num_components=8, min_height_ratio=0.25, max_height_ratio=1):\n",
" # Binarize the image\n",
" binary_image = img < 0.5 # Pixels with a value less than 0.5 will be True (1)\n",
"\n",
" # Label connected components in the image\n",
" label_image = label(binary_image)\n",
"\n",
" # Get properties of the labeled regions\n",
" properties = regionprops(label_image)\n",
"\n",
" # Sort the regions by area (in descending order)\n",
" properties.sort(key=lambda x: x.area, reverse=True)\n",
"\n",
" # Create an empty image to store the result\n",
" filtered_image = np.zeros_like(label_image, dtype=bool)\n",
"\n",
" # Keep only the largest components that satisfy the height constraints\n",
" for prop in properties[:num_components]:\n",
" minr, minc, maxr, maxc = prop.bbox\n",
" height = maxr - minr\n",
" if height > max_height_ratio * img.shape[0] or height < min_height_ratio * img.shape[0]:\n",
" continue\n",
" filtered_image[label_image == prop.label] = 1\n",
"\n",
" return filtered_image == 0\n",
"\n",
"\n",
"def read_and_process(imgpath, cdir):\n",
" img = imread(imgpath, as_gray=True);\n",
" img = np.hstack([img, np.ones((img_height, img_width - img.shape[1]))]).astype(\"float32\")\n",
" img = filter_image(img)\n",
" output_path = os.path.join(cdir, Path(imgpath).stem + \".png\")\n",
" imsave(output_path, np.clip(img_as_ubyte(img), 0, 238))\n",
" return tf.convert_to_tensor((1 - img).astype(\"float32\").reshape((80, 300, 1)));\n",
"\n",
"def load_data(images, labels, cache, shuffle=True):\n",
" os.makedirs(cache, exist_ok=True)\n",
" # 1. Get the total size of the dataset\n",
" size = len(images)\n",
" # 2. Make an indices array and shuffle it, if required\n",
" indices = np.arange(size)\n",
" if shuffle:\n",
" np.random.shuffle(indices)\n",
" # 3. Get the size of training samples\n",
" train_samples = int(size)\n",
" # 4. Split data into training and validation sets\n",
" x_train, y_train = images[indices], labels[indices]\n",
" x_train = [read_and_process(x, cache) for x in x_train]\n",
" return x_train, y_train\n",
"\n",
"\n",
"# Splitting data into training and validation sets\n",
"rx_train, ry_train = load_data(np.array(images), np.array(labels), Path(\"sdir\"))\n",
"\n",
"\n",
"def encode_single_sample(img, label):\n",
" img = tf.image.convert_image_dtype(img, tf.float32)\n",
" # 4. Resize to the desired size\n",
" #img = tf.image.resize_with_pad(img, img_height, img_width)\n",
" # 5. Transpose the image because we want the time\n",
" # dimension to correspond to the width of the image.\n",
" img = tf.transpose(img, perm=[1, 0, 2])\n",
" # 6. Map the characters in label to numbers\n",
" label = char_to_num(tf.strings.unicode_split(label, input_encoding=\"UTF-8\"))\n",
" # 7. Return a dict as our model is expecting two inputs\n",
" return {\"image\": img, \"label\": label}\n"
]
},
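{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sketch of what the lookup layers do: `StringLookup` reserves an extra index for the out-of-vocabulary token `[UNK]`, so the vocabulary has `len(characters) + 1 = 23` entries, and the model's output layer later adds one more class for the CTC blank symbol. Encoding a label and decoding it again should return the original string:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: round-trip one label (e.g. \"4VHHN \", a label from this dataset)\n",
"# through the character lookup layers.\n",
"example = tf.strings.unicode_split(\"4VHHN \", input_encoding=\"UTF-8\")\n",
"encoded = char_to_num(example)\n",
"decoded = tf.strings.reduce_join(num_to_char(encoded)).numpy().decode(\"utf-8\")\n",
"print(encoded.numpy(), repr(decoded))\n",
"print(\"Vocabulary size:\", len(char_to_num.get_vocabulary()))"
]
},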
{
"cell_type": "markdown",
"metadata": {
"id": "fnwhurZ-vl9Z"
},
"source": [
"## Create `Dataset` objects"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"id": "IEec36ZDL5OH",
"outputId": "e79ca0b6-4bce-4830-c863-d7167cd2f666",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2653\n",
"7957\n"
]
}
],
"source": [
"split_index = int(len(rx_train) * 0.75)\n",
"\n",
"# Move the first 75% of x_valid to x_train\n",
"x_train = rx_train[:split_index];\n",
"# Move the first 75% of y_valid to y_train\n",
"y_train = ry_train[:split_index];\n",
"\n",
"# Keep only the last 25% of x_valid\n",
"x_valid = rx_train[split_index:]\n",
"# Keep only the last 25% of y_valid\n",
"y_valid = ry_train[split_index:]\n",
"\n",
"print(len(x_valid))\n",
"print(len(x_train))"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"id": "k2MZdcpXvl9Z"
},
"outputs": [],
"source": [
"train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
"train_dataset = (\n",
" train_dataset.map(\n",
" encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE\n",
" )\n",
" .batch(batch_size)\n",
" .prefetch(buffer_size=tf.data.AUTOTUNE)\n",
")\n",
"\n",
"validation_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))\n",
"validation_dataset = (\n",
" validation_dataset.map(\n",
" encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE\n",
" )\n",
" .batch(batch_size)\n",
" .prefetch(buffer_size=tf.data.AUTOTUNE)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NI0NRV5Ivl9Z"
},
"source": [
"## Visualize the data"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"id": "7GT5RSNgvl9Z",
"outputId": "67525f74-36e2-4ca5-9246-6990c7c5a368",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 405
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x500 with 16 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAAGECAYAAABecT12AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd1hUx/f/37u0XToKItgQEVGJkohClAhELMSCxooi0URjUBH52DB2Sezd2GNBAjYs2LFiwYagYkfEgjRBEJHO7vn9wW/vl3V3YUEQ1Hk9z3mUe+fOPTM7c2fOlDM8IiIwGAwGg8FgMBgMRhXDr2kFGAwGg8FgMBgMxpcJMzYYDAaDwWAwGAxGtcCMDQaDwWAwGAwGg1EtMGODwWAwGAwGg8FgVAvM2GAwGAwGg8FgMBjVAjM2GAwGg8FgMBgMRrXAjA0Gg8FgMBgMBoNRLTBjg8FgMBgMBoPBYFQLzNhgMBgMBoPBYDAY1QIzNhgMBoPBYDAYDEa18FUYG3v37gWPx8PBgwdl7rVt2xY8Hg/nz5+Xude4cWN07NgRPB6vXHFycuKeS0xMxKBBg6Cvrw9dXV24ubkhPj5eKu7nz59zz/71119y9R42bBh4PB60tbWlrjs5OYHH46F3794yz0jiXbZsmcz1kSNHolmzZhAIBKhfvz46d+6MOXPmKMw3BmP9+vXg8Xiws7OTut6qVSu0bdtWJvzBgwfB4/Hg6Ogoc2/btm3g8Xg4deoUAGDHjh1cHbh8+bJMeCJCo0aNwOPx0KtXL6l7PB4P48ePl3lmwYIF4PF4+PXXXyEWi5VKy4EDB8Dj8fDvv/8qyAXg9OnT4PF4WLNmjcIwDIakTN+8eVPqelZWFjp06ACBQICTJ09i7ty54PF4MDY2Rm5urkw8ZmZmMmU+Pz8fCxcuRKtWraCpqYkGDRpg4MCBuH//vlQ4Sdzp6elydZQX94f1SVE7ouw7GIyK8OzZM4wfPx6WlpbQ1NSEpqYmWrVqhXHjxiEmJoYLJ6/cjRgxQqovpqGhAUtLS8yePRv5+fky7yodls/nw9TUFN26dUN4eLhUOHn1BAACAwOhoqKCHj16cPErWze/ZlRrWoFPgYODAwDg8uXL6NevH3f93bt3uHfvHlRVVREREQFnZ2fuXkJCAhISEmBra4vAwEDu+vv37+Hl5YV+/frh559/5q4bGxtz952dnZGVlYU///wTampqWLlyJRwdHXH79m3UrVtXSjeBQIBdu3Zh5syZUtdzcnIQGhoKgUCgMF1Hjx5FVFQU2rVrV2b64+Li0L59ewiFQvz6668wMzNDcnIyoqOjsXjxYsybN6/M5xlfL0FBQTAzM8ONGzcQFxcHCwsLACV1auvWrcjKyoKenh4XPiIiAqqqqoiMjERRURHU1NSk7qmoqOD777+XeodAIEBwcDBXTyVcuHABr169goaGhlK6Llq0CDNmzMAvv/yCf//9F3y+9FiKorT07NkTenp6CA4OxqhRo+TGHRwcDBUVFQwZMkQpXRgMCe/evUO3bt0QExODgwcPokePHrh27RoA4PXr19iwYQMmTZpUbjzDhg3D4cOHMXr0aHz33XdISkrCunXr8P333+Pu3bto0qRJdSeFwahyjh49isGDB0NVVRXDhg1D27Ztwefz8ejRIxw4cAAbNmzAs2fPyizfGhoa3GBRVlYWQkND4e/vj6dPnyIoKEgmfNeuXeHp6QkiwrNnz7B+/Xr8+OOPOHbsGFxdXRW+JygoCCNGjICLiwsOHTrE9c9Y3VQC+kpo2rQpdejQQerayZMnicfjkbu7O3Xv3l3qXnBwMAGg0NBQqetpaWkEgObMmSP3PYsXLyYAdOPGDe7aw4cPSUVFhaZPn85de/bsGQGgn3/+mQDQ7du3peIJCgoiNTU16t27N2lpaUndc3R0pMaNG5OBgQH17t1b6p4k3qVLl3LXxo4dS6qqqvT8+XMZfVNTU+Wmg8GIj48nAHTgwAEyMjKiuXPncvcCAgIIAB0/flzqGXt7exo6dCgBoKtXr0rds7S0pG+//Zb7e/v27VwdMDQ0pKKiIqnwo0ePpnbt2lGTJk2oZ8+eUvcA0Lhx47i/lyxZQgDI09OTRCJRhdJCRPTbb78Rn8+nxMREmWfz8vJIT0+PevTooSirGAwi+r8yHRkZSURE7969I3t7e1JXV6ejR49y4ebMmUMAyMbGhoyNjSk3N1cqng/L/KtXrwgATZ48WSrcuXPnCACtWLFCJu60tDS5OipTn+S1I6Up7x0MhjLExcWRlpYWtWzZkpKSkmTuFxUV0erVq+nly5dEJL/c/fLLLzJ9JLFYTPb29sTj8SglJUXq3odlnYgoJiaGAFC3bt24ax/Wk127dpGKigq5uLhQXl4ed70idfNr5qtYRgWUjMTeunULeXl53LWIiAi0bt0arq6uuHbtmtSyi4iICPB4PHTq1KlC7wkJCUH79u3Rvn177pqVlRW6dOmCvXv3yoT//vvv0bRpUwQHB0tdDwoKQo8ePVCnTh2579HR0YGvry+OHDmC6OjoMnV6+vQpGjZsKNe6rlevnjLJYnyFBAUFwcDAAD179sSAAQOkRogksxARERHctfz8fERHR+Pnn3+Gubm51L20tDTExsbKzF4AgLu7O968eYPTp09z1woLCxESEoKhQ4eWq+eKFSswdepUeHh4YPv27TIzGuWlBQA8PDwgFouxe/dumWePHTuGrKwsDBs2rFxdGAwJ79+/R48ePRAdHY39+/ejZ8+eMmFmz56N1NRUbNiwocy4srOzAfzfDLoEExMTAIBQKKwirRmMT8eSJUuQk5OD7du3c2W5NKqqqpgwYQIaNWpUoXh5PB4cHBxARDJL2OXxzTffwNDQEM+ePZN7f+/evfDw8ICTkxMOHz4steKE1U3l+KqMjaKiIly/fp27FhERgY4dO6Jjx47IysrCvXv3pO5ZWVnJLHsqC7FYjJiYGNja2src69ChA54+fcoVzNK4u7tj9+7dICIAQHp6Ok6dOlVuR8vHxwcGBgaYO3dumeGaNGmChIQEnDt3Tum0MBhBQUH4+eefoa6uDnd3dzx58gSRkZEAAHNzc5iamkrttYiMjERhYSFXp0obG1euXAEAucaGmZkZvv/+e+zatYu7duLECWRlZZW7bGn16tWYNGkShg4dih07dsg1NMpLCwB07twZDRs2lDH6gZIlVJqamujbt2+ZujAYEnJycuDq6orIyEjs27dP7tpvAPjhhx/w448/YsmSJVIDYR/SrFkzNGzYEMuXL8eRI0fw6tUr3LhxA3/88QeaNm0qt55kZGQgPT1dRj7cy1QWubm5cuOQt8+EwagoR48ehYWFhcw+uqrg+fPnAAADA4Nyw2ZmZiIzM1Nuf2///v0YNmwYOnfujCNHjsgYD5Wpm18jX5WxAYDrHBUXF+P69evo1KkTmjVrBmNjY+5ednY27t69K7djVBYZGRkoKCiQa6FLriUlJcncGzp0KF6+fMl1zvbu3QuBQIA+ffqU+T5dXV1MnDix3NmNCRMmQF1dH
V26dMG3336LiRMnIjQ0lDUYDIVERUXh0aNH3IfSwcEBDRs2lJoR6NSpE27cuIGioiIAJQZ606ZNYWJiImNsSOqWojo1dOhQHDp0iOtwBQUFwdHREaampgp1PHr0KCZOnAh3d3fs3LkTKioqlU4Ln8+Hu7s7oqKiEBsby11/9+4djh8/Djc3NxlHDQyGIn755Rdcv34d+/btK/c7PmfOHKSmpmLjxo0Kw6ipqWH//v3Q0tJCnz590KhRI9jZ2eH9+/e4cuUK9PX1ZZ5p0aIFjIyMZCQhIUHpdMyZM0duHEuXLlU6DgZDHu/evUNSUhKsra1l7r19+1bKuC3LEJcgCfv06VMsX74c+/fvh7W1NVq0aCETNj8/H+np6UhLS8ONGzcwcOBAiEQiDBw4UCrcrVu3MGTIEDg4OODo0aNyZykqUze/Rr4aY6Nly5aoW7cu1+m5c+cOcnJy0LFjRwCQ6hxdvXoVIpGowsaGpELI29AqmXaTV2lat26NNm3acCO7wcHBcHNzg6amZrnvl
},
"metadata": {}
}
],
"source": [
"_, ax = plt.subplots(4, 4, figsize=(10, 5))\n",
"for batch in train_dataset.take(1):\n",
" images2 = batch[\"image\"]\n",
" labels2 = batch[\"label\"]\n",
" for i in range(batch_size):\n",
" img = (images2[i] * 255).numpy().astype(\"uint8\")\n",
" label = tf.strings.reduce_join(num_to_char(labels2[i])).numpy().decode(\"utf-8\")\n",
" ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap=\"gray\", vmin=0, vmax=255)\n",
" ax[i // 4, i % 4].set_title(label)\n",
" ax[i // 4, i % 4].axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5pgP4jIIvl9a"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"id": "ddaZyWUFvl9a",
"scrolled": true,
"outputId": "8b407d1a-15bd-4cee-89cb-fb06eb723fda",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"ocr_model_v1\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" image (InputLayer) [(None, 300, 80, 1) 0 [] \n",
" ] \n",
" \n",
" Conv1 (Conv2D) (None, 300, 80, 32) 320 ['image[0][0]'] \n",
" \n",
" pool1 (MaxPooling2D) (None, 150, 40, 32) 0 ['Conv1[0][0]'] \n",
" \n",
" Conv2 (Conv2D) (None, 150, 40, 64) 18496 ['pool1[0][0]'] \n",
" \n",
" pool2 (MaxPooling2D) (None, 75, 20, 64) 0 ['Conv2[0][0]'] \n",
" \n",
" reshape (Reshape) (None, 75, 1280) 0 ['pool2[0][0]'] \n",
" \n",
" dense1 (Dense) (None, 75, 64) 81984 ['reshape[0][0]'] \n",
" \n",
" dropout_9 (Dropout) (None, 75, 64) 0 ['dense1[0][0]'] \n",
" \n",
" bidirectional_19 (Bidirectiona (None, 75, 256) 197632 ['dropout_9[0][0]'] \n",
" l) \n",
" \n",
" bidirectional_20 (Bidirectiona (None, 75, 128) 164352 ['bidirectional_19[0][0]'] \n",
" l) \n",
" \n",
" label (InputLayer) [(None, None)] 0 [] \n",
" \n",
" dense2 (Dense) (None, 75, 24) 3096 ['bidirectional_20[0][0]'] \n",
" \n",
" ctc_loss (CTCLayer) (None, 75, 24) 0 ['label[0][0]', \n",
" 'dense2[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 465,880\n",
"Trainable params: 465,880\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"\n",
"class CTCLayer(layers.Layer):\n",
" def __init__(self, name=None):\n",
" super().__init__(name=name)\n",
" self.loss_fn = keras.backend.ctc_batch_cost\n",
"\n",
" def call(self, y_true, y_pred):\n",
" # Compute the training-time loss value and add it\n",
" # to the layer using `self.add_loss()`.\n",
" batch_len = tf.cast(tf.shape(y_true)[0], dtype=\"int64\")\n",
" input_length = tf.cast(tf.shape(y_pred)[1], dtype=\"int64\")\n",
" label_length = tf.cast(tf.shape(y_true)[1], dtype=\"int64\")\n",
"\n",
" input_length = input_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n",
" label_length = label_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n",
"\n",
" loss = self.loss_fn(y_true, y_pred, input_length, label_length)\n",
" self.add_loss(loss)\n",
"\n",
" # At test time, just return the computed predictions\n",
" return y_pred\n",
"\n",
"\n",
"def build_model():\n",
" # Inputs to the model\n",
" input_img = layers.Input(\n",
" shape=(img_width, img_height, 1), name=\"image\", dtype=\"float32\"\n",
" )\n",
" labels = layers.Input(name=\"label\", shape=(None,), dtype=\"float32\")\n",
"\n",
" # First conv block\n",
" x = layers.Conv2D(32, (3, 3), activation=\"relu\", kernel_initializer=\"he_normal\", padding=\"same\", name=\"Conv1\")(input_img)\n",
" x = layers.MaxPooling2D((2, 2), name=\"pool1\")(x)\n",
"\n",
" # Second conv block\n",
" x = layers.Conv2D(64, (3, 3), activation=\"relu\", kernel_initializer=\"he_normal\", padding=\"same\", name=\"Conv2\")(x)\n",
" x = layers.MaxPooling2D((2, 2), name=\"pool2\")(x)\n",
"\n",
" # We have used two max pool with pool size and strides 2.\n",
" # Hence, downsampled feature maps are 4x smaller. The number of\n",
" # filters in the last layer is 64. Reshape accordingly before\n",
" # passing the output to the RNN part of the model\n",
" new_shape = ((img_width // 4), (img_height // 4) * 64)\n",
" x = layers.Reshape(target_shape=new_shape, name=\"reshape\")(x)\n",
" x = layers.Dense(64, activation=\"relu\", name=\"dense1\")(x)\n",
" x = layers.Dropout(0.2)(x)\n",
"\n",
" # RNNs\n",
" x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)\n",
" x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)\n",
"\n",
" # Output layer\n",
" x = layers.Dense(len(char_to_num.get_vocabulary()) + 1, activation=\"softmax\", name=\"dense2\")(x)\n",
"\n",
" # Add CTC layer for calculating CTC loss at each step\n",
" output = CTCLayer(name=\"ctc_loss\")(labels, x)\n",
"\n",
" # Define the model\n",
" model = keras.models.Model(\n",
" inputs=[input_img, labels], outputs=output, name=\"ocr_model_v1\"\n",
" )\n",
" # Optimizer\n",
" opt = keras.optimizers.Adam()\n",
" # Compile the model and return\n",
" model.compile(optimizer=opt)\n",
" return model\n",
"\n",
"\n",
"# Get the model\n",
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PCAmf-fzvl9a"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"id": "M-R6QGjuvl9a",
"outputId": "1430d686-b65f-470e-ee40-749311d07dcc",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/100\n",
"498/498 [==============================] - 35s 51ms/step - loss: 20.7945 - val_loss: 19.4729\n",
"Epoch 2/100\n",
"498/498 [==============================] - 23s 45ms/step - loss: 16.1562 - val_loss: 5.6912\n",
"Epoch 3/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 3.3449 - val_loss: 1.3842\n",
"Epoch 4/100\n",
"498/498 [==============================] - 22s 43ms/step - loss: 1.5384 - val_loss: 0.9462\n",
"Epoch 5/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 1.1034 - val_loss: 0.7941\n",
"Epoch 6/100\n",
"498/498 [==============================] - 23s 45ms/step - loss: 0.8891 - val_loss: 0.6393\n",
"Epoch 7/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.7451 - val_loss: 0.5670\n",
"Epoch 8/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.6231 - val_loss: 0.5043\n",
"Epoch 9/100\n",
"498/498 [==============================] - 22s 44ms/step - loss: 0.5347 - val_loss: 0.4431\n",
"Epoch 10/100\n",
"498/498 [==============================] - 29s 58ms/step - loss: 0.4694 - val_loss: 0.4051\n",
"Epoch 11/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.4273 - val_loss: 0.3923\n",
"Epoch 12/100\n",
"498/498 [==============================] - 22s 43ms/step - loss: 0.3658 - val_loss: 0.3645\n",
"Epoch 13/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.3459 - val_loss: 0.3111\n",
"Epoch 14/100\n",
"498/498 [==============================] - 23s 46ms/step - loss: 0.3083 - val_loss: 0.4242\n",
"Epoch 15/100\n",
"498/498 [==============================] - 24s 47ms/step - loss: 0.2930 - val_loss: 0.3382\n",
"Epoch 16/100\n",
"498/498 [==============================] - 22s 44ms/step - loss: 0.2523 - val_loss: 0.3169\n",
"Epoch 17/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.2442 - val_loss: 0.3462\n",
"Epoch 18/100\n",
"498/498 [==============================] - 21s 43ms/step - loss: 0.2535 - val_loss: 0.4511\n",
"Epoch 19/100\n",
"498/498 [==============================] - 23s 46ms/step - loss: 0.2347 - val_loss: 0.3736\n",
"Epoch 20/100\n",
"498/498 [==============================] - 22s 44ms/step - loss: 0.2096 - val_loss: 0.3644\n",
"Epoch 21/100\n",
"498/498 [==============================] - 23s 47ms/step - loss: 0.2080 - val_loss: 0.4305\n",
"Epoch 22/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.1874 - val_loss: 0.3905\n",
"Epoch 23/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.1835 - val_loss: 0.2836\n",
"Epoch 24/100\n",
"498/498 [==============================] - 22s 44ms/step - loss: 0.1564 - val_loss: 0.3313\n",
"Epoch 25/100\n",
"498/498 [==============================] - 21s 43ms/step - loss: 0.1750 - val_loss: 0.3248\n",
"Epoch 26/100\n",
"498/498 [==============================] - 22s 44ms/step - loss: 0.1797 - val_loss: 0.2985\n",
"Epoch 27/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.1588 - val_loss: 0.3010\n",
"Epoch 28/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.1469 - val_loss: 0.3371\n",
"Epoch 29/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.1814 - val_loss: 0.3944\n",
"Epoch 30/100\n",
"498/498 [==============================] - 21s 43ms/step - loss: 0.1236 - val_loss: 0.3072\n",
"Epoch 31/100\n",
"498/498 [==============================] - 22s 44ms/step - loss: 0.1224 - val_loss: 0.3949\n",
"Epoch 32/100\n",
"498/498 [==============================] - 22s 45ms/step - loss: 0.1378 - val_loss: 0.4653\n",
"Epoch 33/100\n",
"498/498 [==============================] - 21s 43ms/step - loss: 0.1348 - val_loss: 0.3056\n"
]
}
],
"source": [
"epochs = 100\n",
"early_stopping_patience = 10\n",
"# Add early stopping\n",
"early_stopping = keras.callbacks.EarlyStopping(\n",
" monitor=\"val_loss\", patience=early_stopping_patience, restore_best_weights=True\n",
")\n",
"\n",
"# Train the model\n",
"history = model.fit(\n",
" train_dataset,\n",
" validation_data=validation_dataset,\n",
" epochs=epochs,\n",
" callbacks=[early_stopping],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1EpbnVEEvl9a"
},
"source": [
"## Inference\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/ocr-for-captcha)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/ocr-for-captcha)."
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"id": "8_xv3ktTvl9b",
"outputId": "19ac5608-fbb5-4603-ede8-1a1f15ec6234",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model_6\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" image (InputLayer) [(None, 300, 80, 1)] 0 \n",
" \n",
" Conv1 (Conv2D) (None, 300, 80, 32) 320 \n",
" \n",
" pool1 (MaxPooling2D) (None, 150, 40, 32) 0 \n",
" \n",
" Conv2 (Conv2D) (None, 150, 40, 64) 18496 \n",
" \n",
" pool2 (MaxPooling2D) (None, 75, 20, 64) 0 \n",
" \n",
" reshape (Reshape) (None, 75, 1280) 0 \n",
" \n",
" dense1 (Dense) (None, 75, 64) 81984 \n",
" \n",
" dropout_9 (Dropout) (None, 75, 64) 0 \n",
" \n",
" bidirectional_19 (Bidirecti (None, 75, 256) 197632 \n",
" onal) \n",
" \n",
" bidirectional_20 (Bidirecti (None, 75, 128) 164352 \n",
" onal) \n",
" \n",
" dense2 (Dense) (None, 75, 24) 3096 \n",
" \n",
"=================================================================\n",
"Total params: 465,880\n",
"Trainable params: 465,880\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"1/1 [==============================] - 1s 1s/step\n",
"['JSMYJY', 'JSMYJY']\n",
"['JNKSK ', 'JNKSK ']\n",
"['RHMMVG', 'RHMMVG']\n",
"['NNXY0 ', 'NNXY0 ']\n",
"['YPMTNN', 'YPMTNN']\n",
"['WDKJTX', 'WDKJTX']\n",
"['JYRGNS', 'JYRGNS']\n",
"['PSDD0 ', 'PSDD0 ']\n",
"['KKPPV ', 'KKPPV ']\n",
"['AW2GPD', 'AW2GPD']\n",
"['D0VSJ ', 'D0VSJ ']\n",
"['ST2VWP', 'ST2VWP']\n",
"['4VHHN ', '4VHHN ']\n",
"['TWTNSK', 'TWTNSK']\n",
"['JXYARA', 'JXYARA']\n",
"['8KNRGD', '8KNRGD']\n",
"1/1 [==============================] - 0s 33ms/step\n",
"['PGYYXT', 'PGYYXT']\n"
]
},
{
"output_type": "error",
"ename": "IndexError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-78-d7d1093153ee>\u001b[0m in \u001b[0;36m<cell line: 43>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcomp\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0mtitle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf\"P: {pred} T: {comp} ({dist})\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0max\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcmap\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"gray\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_title\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtitle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'green'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcomp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m'red'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"off\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIndexError\u001b[0m: index 4 is out of bounds for axis 0 with size 4"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1500x500 with 16 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJ4AAAGbCAYAAACF5lr+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd1gUydPHv0sGyQhIRuUUFRUVc+TMGc9wRhQzeqY7L3ieijkH9Mw5nxEVc8CMAUURAxhARSRKzqneP3h3fgy7ywbARezP89Sj9HSo6Zmt6anprhYQEYHBYDAYDAaDwWAwGAwGg8EoY1SUrQCDwWAwGAwGg8FgMBgMBqNywhxPDAaDwWAwGAwGg8FgMBiMcoE5nhgMBoPBYDAYDAaDwWAwGOUCczwxGAwGg8FgMBgMBoPBYDDKBeZ4YjAYDAaDwWAwGAwGg8FglAvM8cRgMBgMBoPBYDAYDAaDwSgXmOOJwWAwGAwGg8FgMBgMBoNRLjDHE4PBYDAYDAaDwWAwGAwGo1xgjicGg8FgMBgMBoPBYDAYDEa5wBxPDEYp6HGwB8adGadQ2RY7WuCPK3+UsUYMBoMhG8x+MRgMxv94GPkQGgs18CHpg9xlL769CN0luohLjysHzRgMRmUnLScNZivNcPDZQbnLfsn4gipLquD8m/PloFnZ8U06nvY83QPBfAEnWou0UGtDLfxy/hfEpMUoVOeN9zcgmC/A8ZfHeenBMcEYcHQA7NbZQWuRFqzWWKHz/s7Y8GADL5/9OnsI5gvQaV8nsfVvf7yd0/fR50fIzc9F/c31UXN9TWTmZorkf5/0HjqLdTDw2EDeOStStjgd9nTg9Z8k8brhJUvXcXU6bXLipeXk58D7vjcabW0E/aX6MFxmiHqb6mG873iExIdw+Ypezzsf74jUTUSwWWsDwXwBeh3qBQD47/l/EMwXYOujrWL18TzrCfWF6giKDgJQeH0ULSuJux/v4vK7y/izzZ+89AIqwIq7K1Dduzq0FmmhweYGOBx8WKT8n63/xMaAjYhOiy6xHUbl4mvZL2E7Wou0EJkSKVJG3G+26O+kKPuD9kN1gSq6HeiGrLwsAIUPyHnX58FpkxOqLKkCkxUmcN7ijGkXpuFz6meurNcNLwjmCxCfEc+rMyI5AjXX14TxcmMERgXKfK5fy34JbfqU81NE8pfU348+P+LlTc5KRrPtzaC1SAsX317k0u98vIPuB7vDao0VtBZpwXatLXof7o1DwYd45QXzBfjl/C8iOiy5vQSC+QKMPj0aBVQg87kKYfar8lKeNkYoqgtUYbbSDAOODsCruFci+UedGgXdJboS6yt+X79Pes/VvejWIrFlhp0cBsF8gUi9Qpvww4YfxJa78u4KV7fwN9vncB/oLNZBanaqRB2HnRwGjYUa+JLxhUvLzsvGhgcb0GZXGxgtN4LGQg1YrrZEn8N9cDj4MPIL8iXWJ7SF0qTDng4S65BWp/pCddivs8fUC1ORlJUkkl+SjQdKtmvyjM+ECMuNPTNWbHuzr83m8sRnxCM3PxdVV1RFm11tJJ6vsL3GWxvz0mPTY/HX1b9Qf3N96C7RhdYiLTisd4DHaQ+xektitt9sDKk/BHaGdrz0V3Gv0O1AN+gu0YXxcmOM8Bkh4mDq5tANDsYOWHpnqcztMcqO8rZ7jz8/Fjkuzs4J7VHvw71F8gvt3Cr/VVzaP37/QDBfgBvvb4jkF76v/PvwX9yLuAeV+SqYdXWWWF2X31kOwXwBzr0+x6XJct/KQvG+lST26+zlrlPeMao847LhJ4dDa5EWXn95LZJ/2Z1lEMwX4Ozrswq9F/pH+KPNrjbQWayDaquqYeqFqUjLSZP5/MXhfd8bepp6GOw0mJeelJWE8b7jYbrSFFWWVIHrXleRcbOJjgnGNhqLOdfnlEqH8kZN2QqUhgUdFqC6UXVk5WXhzsc72PxoM86/OY/nk55DR12n1PX7R/jDda8rbA1sMa7xOFTTrYaI5Ajcj7wP7wfemNKcf+NrqWnh+vvriE6LRjXdarxjB4MPQktNi3tpU1dVx7Ze29B6V2ssvLUQSzou4eX/5fwv0FDVwPpu60X0Kk1ZAJjddjbGNv7fQCAgMgDrH67H323+Rh3TOlx6A/MGMvSSZPof7Y8Lby5gSP0hGNd4HHLzcxESH4Kzb86ilU0rOFZ15OXXUtPCoeBDaGPLH3Tc/HATn1I+QVNVk0sb7DQYe4P24q9rf8HN0Q3muubcsYeRD7EtcBt+a/kbGlZrKKJXacoWZaX/SnSs0REOxg689NnXZmPZ3WUY13gcmlo2xenQ0xh6cigEAgHPmPR17At9TX1sCtiEBa4LSmyLUfkob/slJDs/G8vuLMOGHhukZxbDwWcHMer0KHSq0QmnBp+ClpoWcvNz0W53O4TEh2Bkw5GY0mwK0nLS8CLuBQ49P4R+dfrBUs9SYp2RKZFw3euKhMwEXB1xFY0tGkvMW5yvZb+EbA/cjlltZ5V4PpJIyU5BlwNd8CzmGXx+9kE3h24AgGMvjuHn4z/DuZozpjWfBiMtI4QnhePWh1vYHrgdQ+sPLbHeZXeWYbbfbIxsOBI7+uyAikD+b0jMflV+ysPGTG02FU2tmiI3PxfPYp5hy+MtuPH+Bp5Pei4y7lEELTUtHH5+GP+0+4eXnp6TjtMhp6GlpiWx3NuEt3gY+RDNrJrxjhUffwHAsPrD4PvaFz4hPnBv6C5SX0ZuBk6HnEY3h24w0TEBAMSlx6H7we54HPUYXWt2xT9t/4GxtjGi06JxNfwqhp4circJbzGnvfiB/091fuL93tJy0uB5zhP9HPvhpzo/cenmVczFFS+RzT03Q1dDF+k56bgWfg0bHm5AYFQg7oyW3elSEvKMz4qXO/HqBDb13AQNVQ3escPPD4uMiwfWHYitj7fiQ9IHEecPANz6cAufUj5hRosZXNrDyIfoeagnUrNTMdhpMCY2mQhNNU2EJ4bjVOgp7Hm6BzdH3UQ7u3YlnuPT6Ke4GnYV/qP9eemfUj6h3Z52MNA0wJKOS5CWk4ZV/qsQHBOMh+Me8s5rQpMJmHllJuZ3mA89Tb0S22OUD+U1tvK66QXfIb4y5z/7+iwef36MJpZNSsz3T7t/8N/z/zDx7EQ883zG3U9JWUmYcWkGmlo2xaSmk6AiUMGEJhOw+t5qDG8wHPXM6nF1fEj6gAW3FmBg3YHoWasnAPnuW2m0s2uH/f3289LGnhmLZlbNML7JeC5NV0PyxwZJKDJGlXVctqbrGpx/cx4Tz06E30g/Lj08MRwLbi5A/zr90atWocNcnvfCp9FP0XFfR9SpWgdruq7Bp5RPWOW/Cm8S3uDCsAvynD5Hbn4uvB94Y0aLGVBVUeXSC6gAPQ/1RFB0EH5v9Tuq6lTFpkeb0GFPBzwe/xg/mPzvg8tEl4lY/3A9/ML98GP1HxXSo9yhb5DdT3YTvEABkQG89F8v/krwAh16dkjuOq+HXyd4gY69OMal9TjYg0xXmFJiZqJI/pi0GN7fdmvtqOPejqS/VJ/W3VvHOxaRH
EEq81Wo/5H+Inp7nvUk9QXq9DzmOZd2/MVxghdo08NNJZ6zrGWlcezFMYIX6Hr4dZnLFKf97vZUb2M97u+Hnx4SvECLby0WyZuXn0fx6fHc38Jz++nIT1R1RVXKzc/l5R93Zhw12dqE7NbaUc+DPbn08MRw0lmsQ0OOD+HV7bzFmezX2VN6TjqXXpqy4ohJiyG1BWq04/EOXvqn5E+kvkCdJp+bzKUVFBRQ211tyXqNNeXl5/Hy/3LuF7Jba0cFBQUltseoPHwt+yVsx3mLM2ku1KTIlEhemeK/WSLR38nh4MOkOl+VOu3rRJm5mVz60edHCV6gg88OiuiSmZtJyVnJ3N/zrs8jeIHi0uOIiCgyJZJ+WP8DGS4zFOkDRSgP+0VU2Bf1NtYjtQVqNOX8FN6xkvpbeE4pWSnUYkcL0lioQWdDz/LK191Yl+ptrEfZedkiuhR/tsALPHuy4s4KghfI3ced8gvyFTpfZr8qN1/LxhARbQ7YTPACLb+znJc+0mckVVlcRWJ9xe/r8
},
"metadata": {}
}
],
"source": [
"from sklearn.metrics import pairwise_distances\n",
"\n",
"# Get the prediction model by extracting layers till the output layer\n",
"prediction_model = keras.models.Model(\n",
" model.get_layer(name=\"image\").input, model.get_layer(name=\"dense2\").output\n",
")\n",
"prediction_model.summary()\n",
"\n",
"# A utility function to decode the output of the network\n",
"def decode_batch_predictions(pred):\n",
" input_len = np.ones(pred.shape[0]) * pred.shape[1]\n",
" # Use greedy search. For complex tasks, you can use beam search\n",
" results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][\n",
" :, :max_length\n",
" ]\n",
" # Iterate over the results and get back the text\n",
" output_text = []\n",
" for res in results:\n",
" res = tf.strings.reduce_join(num_to_char(res)).numpy().decode(\"utf-8\")\n",
" output_text.append(res)\n",
" return output_text\n",
"\n",
"def lev(s1, s2):\n",
" m, n = len(s1), len(s2)\n",
" dp = np.zeros((m + 1, n + 1), dtype=int)\n",
"\n",
" for i in range(m + 1):\n",
" for j in range(n + 1):\n",
" if i == 0:\n",
" dp[i][j] = j\n",
" elif j == 0:\n",
" dp[i][j] = i\n",
" elif s1[i - 1] == s2[j - 1]:\n",
" dp[i][j] = dp[i - 1][j - 1]\n",
" else:\n",
" dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])\n",
"\n",
" return dp[m][n]\n",
"# Let's check results on some validation samples\n",
"j = 0\n",
"_, ax = plt.subplots(4, 4, figsize=(15, 5))\n",
"\n",
"for batch in validation_dataset.take(4):\n",
" batch_images = batch[\"image\"]\n",
" batch_labels = batch[\"label\"]\n",
"\n",
" preds = prediction_model.predict(batch_images)\n",
" pred_texts = decode_batch_predictions(preds)\n",
"\n",
" orig_texts = []\n",
" for label in batch_labels:\n",
" #print(tf.strings.reduce_join(num_to_char(label)).numpy().decode(\"utf-8\"))\n",
" label = tf.strings.reduce_join(num_to_char(label)).numpy().decode(\"utf-8\")\n",
" orig_texts.append(label)\n",
"\n",
" for i in range(len(pred_texts)):\n",
" img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)\n",
" img = img.T\n",
" pred = pred_texts[i].replace('[UNK]', '')\n",
" comp = orig_texts[i];\n",
" if len(comp) == 5:\n",
" comp += ' '\n",
" if len(pred) == 5:\n",
" pred += ' '\n",
" dist = lev(pred, comp)\n",
" print([pred, comp])\n",
" title = f\"P: {pred} T: {comp} ({dist})\"\n",
" ax[j // 4, i % 4].imshow(img, cmap=\"gray\")\n",
" ax[j // 4, i % 4].set_title(title, color=('green' if comp in pred else 'red'))\n",
" ax[j // 4, i % 4].axis(\"off\")\n",
" j += 1\n",
"plt.show()"
]
},
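{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid above only shows a handful of samples. As a rough overall check, the sketch below decodes every validation batch with the helpers defined above (`prediction_model`, `decode_batch_predictions`, `lev`) and reports exact-match accuracy together with the mean Levenshtein distance:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: score the whole validation set instead of eyeballing a few samples.\n",
"exact, total, dist_sum = 0, 0, 0\n",
"for batch in validation_dataset:\n",
" preds = prediction_model.predict(batch[\"image\"], verbose=0)\n",
" pred_texts = decode_batch_predictions(preds)\n",
" for pred, label in zip(pred_texts, batch[\"label\"]):\n",
" truth = tf.strings.reduce_join(num_to_char(label)).numpy().decode(\"utf-8\")\n",
" pred = pred.replace(\"[UNK]\", \"\").ljust(max_length)\n",
" exact += int(pred == truth)\n",
" dist_sum += lev(pred, truth)\n",
" total += 1\n",
"print(f\"Exact match: {exact / total:.3f} Mean edit distance: {dist_sum / total:.3f}\")"
]
},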
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"id": "0T6jO2OrL5OK",
"outputId": "1a302d59-669e-4549-ac50-edc8d17494d6",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, lstm_cell_58_layer_call_fn, lstm_cell_58_layer_call_and_return_conditional_losses, lstm_cell_59_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.\n"
]
}
],
"source": [
"model.save('captcha_75_25.h5')\n",
"model.save('captcha_75_25.keras')\n",
"model.save('captcha_75_25.tf')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}