The captcha solver made by and for Japanese high school girls!

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "mjAScbd2vl9P"
},
"source": [
"# OCR model for reading Captchas\n",
"\n",
"**Author:** [A_K_Nain](https://twitter.com/A_K_Nain)<br>\n",
"**Date created:** 2020/06/14<br>\n",
"**Last modified:** 2020/06/26<br>\n",
"**Description:** How to implement an OCR model using CNNs, RNNs and CTC loss."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wWvlZPBJvl9U"
},
"source": [
"## Introduction\n",
"\n",
"This example demonstrates a simple OCR model built with the Functional API. Apart from\n",
"combining a CNN and an RNN, it also illustrates how you can instantiate a new layer\n",
"and use it as an \"Endpoint layer\" for implementing CTC loss. For a detailed\n",
"guide to layer subclassing, please check out\n",
"[this page](https://keras.io/guides/making_new_layers_and_models_via_subclassing/)\n",
"in the developer guides."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yq0Pe4Zuvl9U"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5q-xCl8Qvl9V"
},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"\n",
"from pathlib import Path\n",
"from collections import Counter\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "KIc-3qB0L5OE",
"outputId": "c6fb9b04-386f-4d84-ae46-f85cc4a36647",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 2
}
],
"source": [
"tf.executing_eagerly()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sSm7N--8vl9W"
},
"source": [
"## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)\n",
"The archive `images_10k.zip` has already been downloaded; let's unzip it."
]
},
{
"cell_type": "code",
"source": [
"!unzip -qq images_10k.zip"
],
"metadata": {
"id": "GmxnAtyRz-L2"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "g3EVJfHBvl9X",
"outputId": "9869f54b-be6e-4cdf-8a6c-6c2a034495a4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of images found: 10780\n",
"Number of labels found: 10780\n",
"Number of unique characters: 21\n",
"Characters present: [' ', '0', '2', '4', '8', 'A', 'D', 'G', 'H', 'J', 'K', 'M', 'N', 'P', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']\n"
]
}
],
"source": [
"substitutions = {\n",
" 'B': '8',\n",
" 'F': 'P',\n",
" 'U': 'V',\n",
" '5': 'S',\n",
" '6': 'G',\n",
" 'Z': '2',\n",
" 'O': '0'\n",
"}\n",
"\n",
"def apply_substitutions(input_string):\n",
" output_string = \"\"\n",
" for char in input_string:\n",
" if char in substitutions:\n",
" output_string += substitutions[char]\n",
" else:\n",
" output_string += char\n",
"\n",
" return output_string\n",
"\n",
"data_dir = Path(\"./images_10k/\")\n",
"\n",
"# Get list of all the images\n",
"images = sorted(list(map(str, list(data_dir.glob(\"*.png\")))))\n",
"labels = [apply_substitutions(img.split(os.path.sep)[-1].split(\".png\")[0]) for img in images]\n",
"\n",
"# Maximum length of any captcha in the dataset\n",
"max_length = max([len(label) for label in labels])\n",
"labels = [x + ' ' * (max_length - len(x)) for x in labels]\n",
"\n",
"characters = set(char for label in labels for char in label)\n",
"characters = sorted(list(characters))\n",
"\n",
"print(\"Number of images found: \", len(images))\n",
"print(\"Number of labels found: \", len(labels))\n",
"print(\"Number of unique characters: \", len(characters))\n",
"print(\"Characters present: \", characters)\n",
"\n",
"# Batch size for training and validation\n",
"batch_size = 16\n",
"\n",
"# Desired image dimensions\n",
"img_width = 300\n",
"img_height = 80\n",
"\n",
"# Factor by which the image is going to be downsampled\n",
"# by the convolutional blocks. We will be using two\n",
"# convolution blocks, and each block will have\n",
"# a pooling layer which downsamples the features by a factor of 2.\n",
"# Hence the total downsampling factor is 4.\n",
"downsample_factor = 4"
]
},
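Because the captcha font renders some character pairs near-identically, the cell above collapses each ambiguous character onto one canonical form before the labels are used, shrinking the output alphabet from 28 to 21 symbols. A standalone sketch of that normalization:

```python
# Visually ambiguous characters are mapped onto a single canonical form,
# mirroring the substitution table defined in the notebook.
substitutions = {'B': '8', 'F': 'P', 'U': 'V', '5': 'S', '6': 'G', 'Z': '2', 'O': '0'}

def apply_substitutions(input_string):
    # Replace each character by its canonical form, or keep it unchanged.
    return "".join(substitutions.get(char, char) for char in input_string)

print(apply_substitutions("B6ZOUF"))  # -> "8G20VP"
```

Note that this makes decoding lossy on purpose: the model never has to distinguish, say, `O` from `0`, which the captcha renders identically anyway.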
{
"cell_type": "markdown",
"metadata": {
"id": "gqn-NjRovl9Y"
},
"source": [
"## Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "MjQltH0Mvl9Y"
},
"outputs": [],
"source": [
"# Mapping characters to integers\n",
"char_to_num = layers.StringLookup(\n",
" vocabulary=list(characters), mask_token=None,\n",
")\n",
"\n",
"# Mapping integers back to original characters\n",
"num_to_char = layers.StringLookup(\n",
" vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True\n",
")\n",
"\n",
"\n",
"def split_data(images, labels, train_size=0.75, shuffle=True):\n",
" # 1. Get the total size of the dataset\n",
" size = len(images)\n",
" # 2. Make an indices array and shuffle it, if required\n",
" indices = np.arange(size)\n",
" if shuffle:\n",
" np.random.shuffle(indices)\n",
" # 3. Get the size of training samples\n",
" train_samples = int(size * train_size)\n",
" # 4. Split data into training and validation sets\n",
" x_train, y_train = images[indices[:train_samples]], labels[indices[:train_samples]]\n",
" x_valid, y_valid = images[indices[train_samples:]], labels[indices[train_samples:]]\n",
" return x_train, x_valid, y_train, y_valid\n",
"\n",
"\n",
"# Splitting data into training and validation sets\n",
"x_train, x_valid, y_train, y_valid = split_data(np.array(images), np.array(labels))\n",
"\n",
"\n",
"def encode_single_sample(img_path, label):\n",
" # 1. Read image\n",
" img = tf.io.read_file(img_path)\n",
" # 2. Decode and convert to grayscale\n",
" img = tf.io.decode_png(img, channels=1)\n",
" # 3. Convert to float32 in [0, 1] range\n",
" img = tf.image.convert_image_dtype(img, tf.float32)\n",
" # 4. Resize to the desired size\n",
" img = tf.image.resize(img, [img_height, img_width])\n",
" # 5. Transpose the image because we want the time\n",
" # dimension to correspond to the width of the image.\n",
" img = tf.transpose(img, perm=[1, 0, 2])\n",
" # 6. Map the characters in label to numbers\n",
" label = char_to_num(tf.strings.unicode_split(label, input_encoding=\"UTF-8\"))\n",
" # 7. Return a dict as our model is expecting two inputs\n",
" return {\"image\": img, \"label\": label}"
]
},
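The `StringLookup` pair above maps characters to integer ids and back; with `mask_token=None`, index 0 is reserved for the out-of-vocabulary token `[UNK]` (which is why the inference code later strips `[UNK]` from decoded strings). A minimal pure-Python sketch of the same mapping, with the vocabulary truncated for brevity:

```python
# Pure-Python sketch of the char_to_num / num_to_char StringLookup pair.
# Index 0 is reserved for the OOV token "[UNK]"; real ids start at 1.
vocabulary = [' ', '0', '2', '4', '8', 'A', 'D']  # truncated for the sketch

char_to_num = {ch: i + 1 for i, ch in enumerate(vocabulary)}
num_to_char = {i + 1: ch for i, ch in enumerate(vocabulary)}

def encode(label):
    # Unknown characters fall back to id 0, the [UNK] slot.
    return [char_to_num.get(ch, 0) for ch in label]

def decode(ids):
    return "".join(num_to_char.get(i, "[UNK]") for i in ids)

print(encode("A0"))   # -> [6, 2]
print(decode([6, 2])) # -> "A0"
```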
{
"cell_type": "markdown",
"metadata": {
"id": "fnwhurZ-vl9Z"
},
"source": [
"## Create `Dataset` objects"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "k2MZdcpXvl9Z"
},
"outputs": [],
"source": [
"train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
"train_dataset = (\n",
" train_dataset.map(\n",
" encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE\n",
" )\n",
" .batch(batch_size)\n",
" .prefetch(buffer_size=tf.data.AUTOTUNE)\n",
")\n",
"\n",
"validation_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))\n",
"validation_dataset = (\n",
" validation_dataset.map(\n",
" encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE\n",
" )\n",
" .batch(batch_size)\n",
" .prefetch(buffer_size=tf.data.AUTOTUNE)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NI0NRV5Ivl9Z"
},
"source": [
"## Visualize the data"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "7GT5RSNgvl9Z",
"outputId": "d1ba100d-2b96-448f-f41b-7a77742374cc",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 405
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x500 with 16 Axes>"
],
"image/png": "<base64 data truncated in source; 4x4 grid of labeled training captchas>"
},
"metadata": {}
}
],
"source": [
"_, ax = plt.subplots(4, 4, figsize=(10, 5))\n",
"for batch in train_dataset.take(1):\n",
" images = batch[\"image\"]\n",
" labels = batch[\"label\"]\n",
" for i in range(batch_size):\n",
" img = (images[i] * 255).numpy().astype(\"uint8\")\n",
" label = tf.strings.reduce_join(num_to_char(labels[i])).numpy().decode(\"utf-8\")\n",
" ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap=\"gray\", vmin=0, vmax=255)\n",
" ax[i // 4, i % 4].set_title(label)\n",
" ax[i // 4, i % 4].axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5pgP4jIIvl9a"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "ddaZyWUFvl9a",
"scrolled": true,
"outputId": "f92b93a3-872c-4cf6-9987-1a74b8492c97",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"ocr_model_v1\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" image (InputLayer) [(None, 300, 80, 1) 0 [] \n",
" ] \n",
" \n",
" Conv1 (Conv2D) (None, 300, 80, 32) 320 ['image[0][0]'] \n",
" \n",
" pool1 (MaxPooling2D) (None, 150, 40, 32) 0 ['Conv1[0][0]'] \n",
" \n",
" Conv2 (Conv2D) (None, 150, 40, 64) 18496 ['pool1[0][0]'] \n",
" \n",
" pool2 (MaxPooling2D) (None, 75, 20, 64) 0 ['Conv2[0][0]'] \n",
" \n",
" reshape (Reshape) (None, 75, 1280) 0 ['pool2[0][0]'] \n",
" \n",
" dense1 (Dense) (None, 75, 64) 81984 ['reshape[0][0]'] \n",
" \n",
" dropout (Dropout) (None, 75, 64) 0 ['dense1[0][0]'] \n",
" \n",
" bidirectional (Bidirectional) (None, 75, 256) 197632 ['dropout[0][0]'] \n",
" \n",
" bidirectional_1 (Bidirectional (None, 75, 128) 164352 ['bidirectional[0][0]'] \n",
" ) \n",
" \n",
" label (InputLayer) [(None, None)] 0 [] \n",
" \n",
" dense2 (Dense) (None, 75, 23) 2967 ['bidirectional_1[0][0]'] \n",
" \n",
" ctc_loss (CTCLayer) (None, 75, 23) 0 ['label[0][0]', \n",
" 'dense2[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 465,751\n",
"Trainable params: 465,751\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"class CTCLayer(layers.Layer):\n",
" def __init__(self, name=None):\n",
" super().__init__(name=name)\n",
" self.loss_fn = keras.backend.ctc_batch_cost\n",
"\n",
" def call(self, y_true, y_pred):\n",
" # Compute the training-time loss value and add it\n",
" # to the layer using `self.add_loss()`.\n",
" batch_len = tf.cast(tf.shape(y_true)[0], dtype=\"int64\")\n",
" input_length = tf.cast(tf.shape(y_pred)[1], dtype=\"int64\")\n",
" label_length = tf.cast(tf.shape(y_true)[1], dtype=\"int64\")\n",
"\n",
" input_length = input_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n",
" label_length = label_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n",
"\n",
" loss = self.loss_fn(y_true, y_pred, input_length, label_length)\n",
" self.add_loss(loss)\n",
"\n",
" # At test time, just return the computed predictions\n",
" return y_pred\n",
"\n",
"\n",
"def build_model():\n",
" # Inputs to the model\n",
" input_img = layers.Input(\n",
" shape=(img_width, img_height, 1), name=\"image\", dtype=\"float32\"\n",
" )\n",
" labels = layers.Input(name=\"label\", shape=(None,), dtype=\"float32\")\n",
"\n",
" # First conv block\n",
" x = layers.Conv2D(32, (3, 3), activation=\"relu\", kernel_initializer=\"he_normal\", padding=\"same\", name=\"Conv1\")(input_img)\n",
" x = layers.MaxPooling2D((2, 2), name=\"pool1\")(x)\n",
"\n",
" # Second conv block\n",
" x = layers.Conv2D(64, (3, 3), activation=\"relu\", kernel_initializer=\"he_normal\", padding=\"same\", name=\"Conv2\")(x)\n",
" x = layers.MaxPooling2D((2, 2), name=\"pool2\")(x)\n",
"\n",
" # We have used two max-pooling layers with pool size and strides of 2.\n",
" # Hence, downsampled feature maps are 4x smaller. The number of\n",
" # filters in the last layer is 64. Reshape accordingly before\n",
" # passing the output to the RNN part of the model\n",
" new_shape = ((img_width // 4), (img_height // 4) * 64)\n",
" x = layers.Reshape(target_shape=new_shape, name=\"reshape\")(x)\n",
" x = layers.Dense(64, activation=\"relu\", name=\"dense1\")(x)\n",
" x = layers.Dropout(0.2)(x)\n",
"\n",
" # RNNs\n",
" x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)\n",
" x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)\n",
"\n",
" # Output layer\n",
" x = layers.Dense(len(char_to_num.get_vocabulary()) + 1, activation=\"softmax\", name=\"dense2\")(x)\n",
"\n",
" # Add CTC layer for calculating CTC loss at each step\n",
" output = CTCLayer(name=\"ctc_loss\")(labels, x)\n",
"\n",
" # Define the model\n",
" model = keras.models.Model(\n",
" inputs=[input_img, labels], outputs=output, name=\"ocr_model_v1\"\n",
" )\n",
" # Optimizer\n",
" opt = keras.optimizers.Adam()\n",
" # Compile the model and return\n",
" model.compile(optimizer=opt)\n",
" return model\n",
"\n",
"\n",
"# Get the model\n",
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PCAmf-fzvl9a"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "M-R6QGjuvl9a",
"outputId": "4c1cc55d-e4fb-4a05-e537-672383ff9a5e",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/100\n",
"506/506 [==============================] - 47s 56ms/step - loss: 20.3903 - val_loss: 19.6792\n",
"Epoch 2/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 19.5811 - val_loss: 19.5427\n",
"Epoch 3/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 19.3835 - val_loss: 19.3300\n",
"Epoch 4/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 19.1879 - val_loss: 19.1312\n",
"Epoch 5/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 19.0987 - val_loss: 19.0328\n",
"Epoch 6/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 19.0443 - val_loss: 19.0023\n",
"Epoch 7/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 19.0237 - val_loss: 18.9892\n",
"Epoch 8/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 19.0134 - val_loss: 18.9844\n",
"Epoch 9/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 19.0469 - val_loss: 19.0271\n",
"Epoch 10/100\n",
"506/506 [==============================] - 28s 54ms/step - loss: 19.0421 - val_loss: 19.0218\n",
"Epoch 11/100\n",
"506/506 [==============================] - 28s 56ms/step - loss: 19.0394 - val_loss: 18.9750\n",
"Epoch 12/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 18.9894 - val_loss: 18.8343\n",
"Epoch 13/100\n",
"506/506 [==============================] - 27s 52ms/step - loss: 18.6315 - val_loss: 18.3019\n",
"Epoch 14/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 17.8129 - val_loss: 16.7060\n",
"Epoch 15/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 6.1195 - val_loss: 1.0395\n",
"Epoch 16/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 1.1816 - val_loss: 0.5404\n",
"Epoch 17/100\n",
"506/506 [==============================] - 28s 55ms/step - loss: 0.7410 - val_loss: 0.3403\n",
"Epoch 18/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.5583 - val_loss: 0.3136\n",
"Epoch 19/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.3938 - val_loss: 0.2274\n",
"Epoch 20/100\n",
"506/506 [==============================] - 28s 54ms/step - loss: 0.3126 - val_loss: 0.1942\n",
"Epoch 21/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.2715 - val_loss: 0.1614\n",
"Epoch 22/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 0.2289 - val_loss: 0.1321\n",
"Epoch 23/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 0.2157 - val_loss: 0.2784\n",
"Epoch 24/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.1764 - val_loss: 0.1141\n",
"Epoch 25/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.1557 - val_loss: 0.1546\n",
"Epoch 26/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.1552 - val_loss: 0.1312\n",
"Epoch 27/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.1342 - val_loss: 0.1289\n",
"Epoch 28/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.1189 - val_loss: 0.1213\n",
"Epoch 29/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.1313 - val_loss: 0.1317\n",
"Epoch 30/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 0.1032 - val_loss: 0.0887\n",
"Epoch 31/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.1297 - val_loss: 0.1856\n",
"Epoch 32/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.1114 - val_loss: 0.0882\n",
"Epoch 33/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0776 - val_loss: 0.1376\n",
"Epoch 34/100\n",
"506/506 [==============================] - 27s 52ms/step - loss: 0.0889 - val_loss: 0.0902\n",
"Epoch 35/100\n",
"506/506 [==============================] - 27s 52ms/step - loss: 0.0853 - val_loss: 0.0919\n",
"Epoch 36/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0721 - val_loss: 0.1522\n",
"Epoch 37/100\n",
"506/506 [==============================] - 28s 55ms/step - loss: 0.0805 - val_loss: 0.1482\n",
"Epoch 38/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0901 - val_loss: 0.1106\n",
"Epoch 39/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 0.0758 - val_loss: 0.1010\n",
"Epoch 40/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0695 - val_loss: 0.0960\n",
"Epoch 41/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0687 - val_loss: 0.0932\n",
"Epoch 42/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.0701 - val_loss: 0.0785\n",
"Epoch 43/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 0.0610 - val_loss: 0.1160\n",
"Epoch 44/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.0693 - val_loss: 0.0857\n",
"Epoch 45/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.0532 - val_loss: 0.0981\n",
"Epoch 46/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 0.0484 - val_loss: 0.1173\n",
"Epoch 47/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0624 - val_loss: 0.0861\n",
"Epoch 48/100\n",
"506/506 [==============================] - 27s 52ms/step - loss: 0.0798 - val_loss: 0.0880\n",
"Epoch 49/100\n",
"506/506 [==============================] - 29s 57ms/step - loss: 0.0542 - val_loss: 0.0849\n",
"Epoch 50/100\n",
"506/506 [==============================] - 27s 52ms/step - loss: 0.0450 - val_loss: 0.0782\n",
"Epoch 51/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0493 - val_loss: 0.0880\n",
"Epoch 52/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0627 - val_loss: 0.0825\n",
"Epoch 53/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.0509 - val_loss: 0.0714\n",
"Epoch 54/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0616 - val_loss: 0.1088\n",
"Epoch 55/100\n",
"506/506 [==============================] - 28s 55ms/step - loss: 0.0640 - val_loss: 0.0834\n",
"Epoch 56/100\n",
"506/506 [==============================] - 28s 55ms/step - loss: 0.0548 - val_loss: 0.1003\n",
"Epoch 57/100\n",
"506/506 [==============================] - 27s 53ms/step - loss: 0.0436 - val_loss: 0.0941\n",
"Epoch 58/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0315 - val_loss: 0.0977\n",
"Epoch 59/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 0.0607 - val_loss: 0.0941\n",
"Epoch 60/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.0402 - val_loss: 0.1110\n",
"Epoch 61/100\n",
"506/506 [==============================] - 27s 54ms/step - loss: 0.0300 - val_loss: 0.0891\n",
"Epoch 62/100\n",
"506/506 [==============================] - 26s 52ms/step - loss: 0.0508 - val_loss: 0.0789\n",
"Epoch 63/100\n",
"506/506 [==============================] - 26s 51ms/step - loss: 0.0481 - val_loss: 0.0834\n"
]
}
],
"source": [
"epochs = 100\n",
"early_stopping_patience = 10\n",
"# Add early stopping\n",
"early_stopping = keras.callbacks.EarlyStopping(\n",
" monitor=\"val_loss\", patience=early_stopping_patience, restore_best_weights=True\n",
")\n",
"\n",
"# Train the model\n",
"history = model.fit(\n",
" train_dataset,\n",
" validation_data=validation_dataset,\n",
" epochs=epochs,\n",
" callbacks=[early_stopping],\n",
")"
]
},
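The `EarlyStopping` callback tracks the best `val_loss` and halts once it has not improved for `patience` (here 10) consecutive epochs, restoring the best weights; that is consistent with the run above stopping at epoch 63, ten epochs after the minimum of 0.0714 at epoch 53. A minimal sketch of the stopping rule (ignoring `min_delta` details):

```python
# Sketch of EarlyStopping on a min-monitored metric: stop once val_loss
# fails to improve for `patience` consecutive epochs, remembering which
# epoch's weights would be restored by restore_best_weights=True.
def early_stop_epoch(val_losses, patience):
    best, best_epoch, wait = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # (stopped epoch, best epoch)
    return len(val_losses) - 1, best_epoch

print(early_stop_epoch([0.9, 0.5, 0.6, 0.7, 0.8], patience=3))  # -> (4, 1)
```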
{
"cell_type": "markdown",
"metadata": {
"id": "1EpbnVEEvl9a"
},
"source": [
"## Inference\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/ocr-for-captcha)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/ocr-for-captcha)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "8_xv3ktTvl9b",
"outputId": "6fca3298-33aa-4c09-b9fb-ac31d91b2249",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" image (InputLayer) [(None, 300, 80, 1)] 0 \n",
" \n",
" Conv1 (Conv2D) (None, 300, 80, 32) 320 \n",
" \n",
" pool1 (MaxPooling2D) (None, 150, 40, 32) 0 \n",
" \n",
" Conv2 (Conv2D) (None, 150, 40, 64) 18496 \n",
" \n",
" pool2 (MaxPooling2D) (None, 75, 20, 64) 0 \n",
" \n",
" reshape (Reshape) (None, 75, 1280) 0 \n",
" \n",
" dense1 (Dense) (None, 75, 64) 81984 \n",
" \n",
" dropout (Dropout) (None, 75, 64) 0 \n",
" \n",
" bidirectional (Bidirectiona (None, 75, 256) 197632 \n",
" l) \n",
" \n",
" bidirectional_1 (Bidirectio (None, 75, 128) 164352 \n",
" nal) \n",
" \n",
" dense2 (Dense) (None, 75, 23) 2967 \n",
" \n",
"=================================================================\n",
"Total params: 465,751\n",
"Trainable params: 465,751\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"1/1 [==============================] - 1s 1s/step\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1500x500 with 16 Axes>"
],
"image/png": "<base64 data truncated in source; 4x4 grid of validation captchas with predicted vs. true labels>"
},
"metadata": {}
}
],
"source": [
"# Get the prediction model by extracting layers till the output layer\n",
"prediction_model = keras.models.Model(\n",
" model.get_layer(name=\"image\").input, model.get_layer(name=\"dense2\").output\n",
")\n",
"prediction_model.summary()\n",
"\n",
"# A utility function to decode the output of the network\n",
"def decode_batch_predictions(pred):\n",
" input_len = np.ones(pred.shape[0]) * pred.shape[1]\n",
" # Use greedy search. For complex tasks, you can use beam search\n",
" results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][\n",
" :, :max_length\n",
" ]\n",
" # Iterate over the results and get back the text\n",
" output_text = []\n",
" for res in results:\n",
" res = tf.strings.reduce_join(num_to_char(res)).numpy().decode(\"utf-8\")\n",
" output_text.append(res)\n",
" return output_text\n",
"\n",
"def lev(s1, s2):\n",
"    \"\"\"Levenshtein (edit) distance between s1 and s2 via dynamic programming.\"\"\"\n",
"    m, n = len(s1), len(s2)\n",
"    dp = np.zeros((m + 1, n + 1), dtype=int)\n",
"\n",
" for i in range(m + 1):\n",
" for j in range(n + 1):\n",
" if i == 0:\n",
" dp[i][j] = j\n",
" elif j == 0:\n",
" dp[i][j] = i\n",
" elif s1[i - 1] == s2[j - 1]:\n",
" dp[i][j] = dp[i - 1][j - 1]\n",
" else:\n",
" dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])\n",
"\n",
"    return dp[m][n]\n",
"\n",
"# Let's check results on some validation samples\n",
"_, ax = plt.subplots(4, 4, figsize=(15, 5))\n",
"\n",
"for batch in validation_dataset.take(1):\n",
" batch_images = batch[\"image\"]\n",
" batch_labels = batch[\"label\"]\n",
"\n",
" preds = prediction_model.predict(batch_images)\n",
" pred_texts = decode_batch_predictions(preds)\n",
"\n",
" orig_texts = []\n",
"    for label in batch_labels:\n",
"        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode(\"utf-8\")\n",
"        orig_texts.append(label)\n",
"\n",
" for i in range(len(pred_texts)):\n",
" img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)\n",
" img = img.T\n",
"\n",
"        pred = pred_texts[i].replace('[UNK]', '')\n",
"        comp = orig_texts[i]\n",
"        # Pad 5-character strings so predictions and labels line up in the titles\n",
"        if len(comp) == 5:\n",
"            comp += ' '\n",
"        if len(pred) == 5:\n",
"            pred += ' '\n",
"\n",
"        dist = lev(pred, comp)\n",
"        title = f\"P: {pred} T: {comp} ({dist})\"\n",
" ax[i // 4, i % 4].imshow(img, cmap=\"gray\")\n",
"        ax[i // 4, i % 4].set_title(title, color=('green' if pred == comp else 'red'))\n",
" ax[i // 4, i % 4].axis(\"off\")\n",
"\n",
"plt.show()"
]
},
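{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `lev` helper above is the classic Levenshtein (edit) distance, built up row by row with dynamic programming. As a quick sanity check, the textbook pair `kitten`/`sitting` should come out at distance 3. This cell is purely illustrative and can be deleted without affecting the pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# kitten -> sitten -> sittin -> sitting: three single-character edits\n",
"assert lev(\"kitten\", \"sitting\") == 3\n",
"# Identical strings are at distance zero\n",
"assert lev(\"2b827\", \"2b827\") == 0"
]
},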
{
"cell_type": "code",
"source": [
"solved = []\n",
"for batch in validation_dataset:\n",
" batch_images = batch[\"image\"]\n",
" batch_labels = batch[\"label\"]\n",
"\n",
" preds = prediction_model.predict(batch_images)\n",
" pred_texts = decode_batch_predictions(preds)\n",
"\n",
" orig_texts = []\n",
" for label in batch_labels:\n",
" label = tf.strings.reduce_join(num_to_char(label)).numpy().decode(\"utf-8\")\n",
" orig_texts.append(label)\n",
"\n",
"    for i in range(len(pred_texts)):\n",
"        pred = pred_texts[i].replace('[UNK]', '')\n",
"        comp = orig_texts[i]\n",
"        solved.append(comp == pred)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9bFKwT0Z0c7X",
"outputId": "d3783e8e-7156-4ab8-d7e7-b0e6e2f4848d"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 36ms/step\n",
"1/1 [==============================] - 0s 38ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 39ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 34ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 39ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 38ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 41ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 34ms/step\n",
"1/1 [==============================] - 0s 63ms/step\n",
"1/1 [==============================] - 0s 52ms/step\n",
"1/1 [==============================] - 0s 39ms/step\n",
"1/1 [==============================] - 0s 40ms/step\n",
"1/1 [==============================] - 0s 39ms/step\n",
"1/1 [==============================] - 0s 45ms/step\n",
"1/1 [==============================] - 0s 44ms/step\n",
"1/1 [==============================] - 0s 42ms/step\n",
"1/1 [==============================] - 0s 41ms/step\n",
"1/1 [==============================] - 0s 40ms/step\n",
"1/1 [==============================] - 0s 44ms/step\n",
"1/1 [==============================] - 0s 44ms/step\n",
"1/1 [==============================] - 0s 44ms/step\n",
"1/1 [==============================] - 0s 41ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 34ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 36ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 36ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 36ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 34ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 37ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 35ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 28ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 46ms/step\n",
"1/1 [==============================] - 0s 38ms/step\n",
"1/1 [==============================] - 0s 46ms/step\n",
"1/1 [==============================] - 0s 44ms/step\n",
"1/1 [==============================] - 0s 44ms/step\n",
"1/1 [==============================] - 0s 42ms/step\n",
"1/1 [==============================] - 0s 48ms/step\n",
"1/1 [==============================] - 0s 48ms/step\n",
"1/1 [==============================] - 0s 42ms/step\n",
"1/1 [==============================] - 0s 40ms/step\n",
"1/1 [==============================] - 0s 43ms/step\n",
"1/1 [==============================] - 0s 47ms/step\n",
"1/1 [==============================] - 0s 38ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 33ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 32ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 31ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 30ms/step\n",
"1/1 [==============================] - 0s 34ms/step\n",
"1/1 [==============================] - 0s 29ms/step\n",
"1/1 [==============================] - 0s 25ms/step\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f'Solved {solved.count(True)}/{len(solved)} ({100 * (sum(solved) / len(solved)):.04f}%)')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RTfy0Jrp0txP",
"outputId": "bb89cdc3-0f92-453e-8fad-2675312a19d1"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Solved 2657/2695 (98.5900%)\n"
]
}
]
},
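{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the greedy decode used in `decode_batch_predictions` picks the most likely class at every timestep, then collapses consecutive repeats and drops the blank token. Below is a minimal pure-Python sketch of that collapse step; `greedy_ctc_collapse` and the `blank=-1` convention are illustrative assumptions (Keras actually reserves the last class index for the blank), not part of the notebook's pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def greedy_ctc_collapse(indices, blank=-1):\n",
"    \"\"\"Collapse repeated symbols, then drop blanks, as in greedy CTC decoding.\"\"\"\n",
"    out, prev = [], None\n",
"    for idx in indices:\n",
"        if idx != prev and idx != blank:\n",
"            out.append(idx)\n",
"        prev = idx\n",
"    return out\n",
"\n",
"# Frame-wise argmaxes [1, 1, -1, 2, 2, -1, 2] decode to [1, 2, 2]\n",
"assert greedy_ctc_collapse([1, 1, -1, 2, 2, -1, 2]) == [1, 2, 2]"
]
},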
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "0T6jO2OrL5OK",
"outputId": "0f98dd24-214f-4f86-8d07-b3d2934e6517",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, lstm_cell_1_layer_call_fn, lstm_cell_1_layer_call_and_return_conditional_losses, lstm_cell_2_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.\n"
]
}
],
"source": [
"# Save the trained model in three formats\n",
"model.save('captcha_75_25.h5')     # legacy HDF5 file\n",
"model.save('captcha_75_25.keras')  # native Keras format\n",
"model.save('captcha_75_25.tf')     # TensorFlow SavedModel directory"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}