Implemented FastKAN in TinyGrad #16

Open · wants to merge 3 commits into master
653 changes: 653 additions & 0 deletions notebooks/FastKANTinyGrad.ipynb
@@ -0,0 +1,653 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 65,
"id": "cc5a5054-7017-4139-943d-d3e8b0999b46",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"METAL\n"
]
}
],
"source": [
"from tinygrad import Device\n",
"print (Device.DEFAULT)\n",
"from typing import *"
]
},
{
"cell_type": "code",
"execution_count": 474,
"id": "922745a9-09f7-41f2-a406-d31c59d45c80",
"metadata": {},
"outputs": [],
"source": [
"from tinygrad import Tensor, nn\n",
"import tinygrad.function as F\n",
"import numpy as np\n",
"\n",
"class SplineLinearFunction:\n",
" def __init__(\n",
" self,\n",
" in_features: int, out_features: int, init_scale: float = 0.1\n",
" ):\n",
" self.init_scale = init_scale\n",
" self.linear_function = nn.Linear(in_features, out_features, bias=False)\n",
"\n",
" def __call__(self, x:Tensor) -> Tensor:\n",
" return self.linear_function(x)\n",
"\n",
"class RadialBasisFunction:\n",
" def __init__(\n",
" self,\n",
" grid_min = -2.,\n",
" grid_max = 2.,\n",
" num_grids = 8,\n",
" denominator = None\n",
" ):\n",
" self.grid_min = grid_min\n",
" self.grid_max = grid_max\n",
" self.num_grids = num_grids\n",
" # You don't need a special Parameter initialization here.\n",
" # You can initialize a Tensor later with this\n",
" self.grid = Tensor(np.linspace(grid_min, grid_max, num_grids, dtype=np.float32), requires_grad=True)\n",
" self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)\n",
" def __call__(self, x):\n",
" return (-(((x[..., None] - self.grid) / self.denominator) ** 2)).exp()"
]
},
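{
"cell_type": "code",
"execution_count": null,
"id": "rbf-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sanity check (not part of the original PR): the RBF above\n",
"# computes exp(-((x - g_i) / h) ** 2) for each grid point g_i, so one\n",
"# value can be cross-checked against plain numpy.\n",
"rbf_check = RadialBasisFunction()\n",
"x_np = np.array([0.5], dtype=np.float32)\n",
"grid_np = np.linspace(-2., 2., 8, dtype=np.float32)\n",
"h = (2. - (-2.)) / (8 - 1)\n",
"expected = np.exp(-((x_np[:, None] - grid_np) / h) ** 2)\n",
"print(np.allclose(rbf_check(Tensor(x_np)).numpy(), expected, atol=1e-5))"
]
},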
{
"cell_type": "code",
"execution_count": 475,
"id": "4184fb72-c4a7-4ac6-b9e2-1648335cf57d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Tensor <LB METAL (4, 8) float (<UnaryOps.EXP2: 1>, None)> on METAL with grad None>\n"
]
},
{
"data": {
"text/plain": [
"<Tensor <LB METAL (3,) float ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),))> on METAL with grad None>"
]
},
"execution_count": 475,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rbf = RadialBasisFunction()\n",
"print(rbf(Tensor([0., 1., 3., 10.])))\n",
"\n",
"slf = SplineLinearFunction(1, 3, 1)\n",
"slf(Tensor([10]))"
]
},
{
"cell_type": "code",
"execution_count": 478,
"id": "3f6838a0-9247-455f-8146-2dc53fdb6131",
"metadata": {},
"outputs": [],
"source": [
"class FastKANLayer:\n",
" def __init__(\n",
" self,\n",
" input_dim: int,\n",
" output_dim: int,\n",
" grid_min: float = -2.,\n",
" grid_max: float = 2.,\n",
" num_grids: int = 8,\n",
" use_base_update: bool = True,\n",
" use_layernorm: bool = True,\n",
" base_activation = Tensor.silu,\n",
" spline_weight_init_scale: float = 0.1\n",
" ) -> None:\n",
" self.input_dim = input_dim\n",
" self.output_dim = output_dim\n",
" # normally you'd init layernorm here.\n",
" # but because layernorm *isn't* a layer in tinygrad,\n",
" # it's a function, I'm gonna hold off until the call\n",
" self.layernorm = None\n",
" if use_layernorm:\n",
" assert input_dim > 1, \"Do not use layernorms on 1D inputs. Set `use_layernorm=False`.\"\n",
" self.layernorm = nn.LayerNorm(input_dim)\n",
" self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)\n",
" self.spline_linear = SplineLinearFunction(input_dim * num_grids, output_dim, spline_weight_init_scale)\n",
" self.use_base_update = use_base_update\n",
" if use_base_update:\n",
" self.base_activation = base_activation\n",
" self.base_linear = nn.Linear(input_dim, output_dim)\n",
"\n",
" def __call__(\n",
" self, x: Tensor, use_layernorm=True\n",
" ) -> Tensor:\n",
" if self.layernorm is not None and use_layernorm:\n",
" spline_basis = self.rbf(self.layernorm(x))\n",
" else:\n",
" spline_basis = self.rbf(x)\n",
" spline_basis_view = spline_basis.view(*spline_basis.shape[:-2], -1)\n",
" ret = self.spline_linear(spline_basis_view)\n",
" if self.use_base_update:\n",
" base = self.base_linear(self.base_activation(x))\n",
" ret = ret + base\n",
" return ret\n",
"\n",
" def plot_curve(\n",
" self,\n",
" input_index: int,\n",
" output_index: int,\n",
" num_pts: int = 1000,\n",
" num_extrapolate_bins: int = 2\n",
" ):\n",
" '''this function returns the learned curves in a FastKANLayer.\n",
" input_index: the selected index of the input, in [0, input_dim) .\n",
" output_index: the selected index of the output, in [0, output_dim) .\n",
" num_pts: num of points sampled for the curve.\n",
" num_extrapolate_bins (N_e): num of bins extrapolating from the given grids. The curve \n",
" will be calculate in the range of [grid_min - h * N_e, grid_max + h * N_e].\n",
" '''\n",
" ng = self.rbf.num_grids\n",
" h = self.rbf.denominator\n",
" assert input_index < self.input_dim\n",
" assert output_index < self.output_dim\n",
" w = self.spline_linear.linear_function.weight[\n",
" output_index, input_index * ng : (input_index + 1) * ng\n",
" ] # num_grids,\n",
" x = Tensor(np.linspace(\n",
" self.rbf.grid_min - num_extrapolate_bins * h,\n",
" self.rbf.grid_max + num_extrapolate_bins * h,\n",
" num_pts\n",
" )) # num_pts, num_grids\n",
" Tensor.no_grad = True\n",
" y = (w * self.rbf(x)).sum(-1)\n",
" Tensor.no_grad = False\n",
" return x, y"
]
},
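{
"cell_type": "code",
"execution_count": null,
"id": "fastkan-layer-shapes",
"metadata": {},
"outputs": [],
"source": [
"# Editor's shape walkthrough (illustrative, not from the original PR):\n",
"# a FastKANLayer maps (batch, input_dim) -> RBF -> (batch, input_dim, num_grids),\n",
"# flattens the last two dims, then applies a single linear layer.\n",
"demo_layer = FastKANLayer(4, 3, num_grids=8)\n",
"xb = Tensor.randn(5, 4)\n",
"basis = demo_layer.rbf(demo_layer.layernorm(xb))\n",
"print(basis.shape)                              # (5, 4, 8)\n",
"print(basis.view(*basis.shape[:-2], -1).shape)  # (5, 32), fed to spline_linear\n",
"print(demo_layer(xb).shape)                     # (5, 3)"
]
},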
{
"cell_type": "code",
"execution_count": 479,
"id": "a28a1eea-71f7-4cc3-9602-da1ab48cdf35",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Tensor <LB METAL (2, 2) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>"
]
},
"execution_count": 479,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fastKANLayer = FastKANLayer(2, 2)\n",
"fastKANLayer(Tensor([[1, 2], [1, 2]]))"
]
},
{
"cell_type": "code",
"execution_count": 480,
"id": "bc957af7-a460-48a8-ac57-482d67363048",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<>:21: SyntaxWarning: invalid escape sequence '\\p'\n",
"<>:21: SyntaxWarning: invalid escape sequence '\\p'\n",
"/var/folders/w1/gydggfx96qv449qgnv_xxj000000gn/T/ipykernel_51246/3288975105.py:21: SyntaxWarning: invalid escape sequence '\\p'\n",
" plt.ylabel(\"$\\phi_{p,q}(x)$\")\n"
]
},
{
"ename": "ValueError",
"evalue": "setting an array element with a sequence.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;31mTypeError\u001b[0m: float() argument must be a string or a real number, not 'Tensor'",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[480], line 19\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m j \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(d_out):\n\u001b[1;32m 18\u001b[0m x, y \u001b[38;5;241m=\u001b[39m layer\u001b[38;5;241m.\u001b[39mplot_curve(i, j, \u001b[38;5;241m200\u001b[39m, num_extrapolate_bins\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m)\n\u001b[0;32m---> 19\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m)\u001b[49m, y\u001b[38;5;241m.\u001b[39mcast(dtypes\u001b[38;5;241m.\u001b[39mfloat)\u001b[38;5;241m.\u001b[39mnumpy(), label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m$\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mphi_\u001b[39m\u001b[38;5;124m{\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m,\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mj\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m}$\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m plt\u001b[38;5;241m.\u001b[39mxlabel(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m$x$\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 21\u001b[0m plt\u001b[38;5;241m.\u001b[39mylabel(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m$\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mphi_\u001b[39m\u001b[38;5;124m{\u001b[39m\u001b[38;5;124mp,q}(x)$\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mValueError\u001b[0m: setting an array element with a sequence."
]
}
],
"source": [
"from tinygrad import dtypes\n",
"d_in = 2\n",
"d_out = 3\n",
"\n",
"layer = FastKANLayer(\n",
" d_in, d_out,\n",
" use_base_update=False,\n",
" use_layernorm=False\n",
")\n",
"\n",
"x, y = layer.plot_curve(0, 1, num_pts=1000, num_extrapolate_bins=3)\n",
"x.shape, y.shape\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"for i in range(d_in):\n",
" for j in range(d_out):\n",
" x, y = layer.plot_curve(i, j, 200, num_extrapolate_bins=3)\n",
" plt.plot(np.array(x, dtype=float), y.cast(dtypes.float).numpy(), label=r\"$\\phi_{\" + f\"{i},{j}\" + r\"}$\")\n",
"plt.xlabel(\"$x$\")\n",
"plt.ylabel(\"$\\phi_{p,q}(x)$\")\n",
"plt.legend(loc=\"upper right\")"
]
},
{
"cell_type": "code",
"execution_count": 481,
"id": "b98d7920-0f18-4813-892d-9bb432095ecc",
"metadata": {},
"outputs": [],
"source": [
"class FastKAN:\n",
" def __init__(\n",
" self,\n",
" layers_hidden: List[int],\n",
" grid_min: float = -2.,\n",
" grid_max: float = 2.,\n",
" num_grids: int = 8,\n",
" use_base_update: bool = True,\n",
" base_activation = Tensor.silu,\n",
" spline_weight_init_scale: float = 0.1,\n",
" ) -> None:\n",
" self.layers = [\n",
" FastKANLayer(\n",
" in_dim, out_dim,\n",
" grid_min=grid_min,\n",
" grid_max=grid_max,\n",
" num_grids=num_grids,\n",
" use_base_update=use_base_update,\n",
" base_activation=base_activation,\n",
" spline_weight_init_scale=spline_weight_init_scale,\n",
" ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])\n",
" ]\n",
"\n",
" def __call__(self, x):\n",
" for layer in self.layers:\n",
" x = layer(x)\n",
" return x"
]
},
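{
"cell_type": "code",
"execution_count": null,
"id": "fastkan-param-count",
"metadata": {},
"outputs": [],
"source": [
"# Editor's addition: count trainable parameters. Each FastKANLayer holds a\n",
"# layernorm, the shared RBF grid, an (in_dim * num_grids) x out_dim spline\n",
"# weight, and (optionally) a base linear layer.\n",
"def count_params(m):\n",
"    return sum(t.numel() for t in nn.state.get_parameters(m))\n",
"\n",
"print(count_params(FastKAN([2, 3, 1])))"
]
},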
{
"cell_type": "code",
"execution_count": 482,
"id": "a277cad6-2777-444b-a26a-477ed256fd9d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Tensor <LB METAL (3, 1) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>"
]
},
"execution_count": 482,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fastKAN = FastKAN([2, 3, 1])\n",
"fastKAN(Tensor([[2, 1], [1, 2], [5, 9]]))"
]
},
{
"cell_type": "code",
"execution_count": 458,
"id": "5c6a297c-da13-481c-ad07-0c8428c1896e",
"metadata": {},
"outputs": [],
"source": [
"class AttentionWithFastKANTransform:\n",
" def __init__(\n",
" self,\n",
" q_dim: int,\n",
" k_dim: int,\n",
" v_dim: int,\n",
" head_dim: int,\n",
" num_heads: int,\n",
" gating: bool = True,\n",
" ):\n",
" self.num_heads = num_heads\n",
" total_dim = head_dim * self.num_heads\n",
" self.gating = gating\n",
" self.linear_q = FastKANLayer(q_dim, total_dim)\n",
" self.linear_k = FastKANLayer(k_dim, total_dim)\n",
" self.linear_v = FastKANLayer(v_dim, total_dim)\n",
" self.linear_o = FastKANLayer(total_dim, q_dim)\n",
" self.linear_g = None\n",
" if self.gating:\n",
" self.linear_g = FastKANLayer(q_dim, total_dim)\n",
" # precompute the 1/sqrt(head_dim)\n",
" self.norm = head_dim**-0.5\n",
"\n",
" def __call__(\n",
" self,\n",
" q: Tensor,\n",
" k: Tensor,\n",
" v: Tensor,\n",
" bias: Tensor = None, # additive attention bias\n",
" ) -> Tensor: \n",
"\n",
" wq = self.linear_q(q).view(*q.shape[:-1], 1, self.num_heads, -1) * self.norm # *q1hc\n",
" wk = self.linear_k(k).view(*k.shape[:-2], 1, k.shape[-2], self.num_heads, -1) # *1khc\n",
" att = (wq * wk).sum(-1).softmax(-2) # *qkh\n",
" del wq, wk\n",
" if bias is not None:\n",
" att = att + bias[..., None]\n",
"\n",
" wv = self.linear_v(v).view(*v.shape[:-2],1, v.shape[-2], self.num_heads, -1) # *1khc\n",
" o = (att[..., None] * wv).sum(-3) # *qhc\n",
" del att, wv\n",
"\n",
" o = o.view(*o.shape[:-2], -1) # *q(hc)\n",
"\n",
" if self.linear_g is not None:\n",
" # gating, use raw query input\n",
" g = self.linear_g(q)\n",
" o = torch.sigmoid(g) * o\n",
"\n",
" # merge heads\n",
" o = self.linear_o(o)\n",
" return o"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "a6072486-7be0-4597-b384-49b422c8a6e2",
"metadata": {},
"outputs": [],
"source": [
"attentionWithFastKANTransform = AttentionWithFastKANTransform(\n",
" 2, 3, 4, 10, 5\n",
")"
]
},
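{
"cell_type": "code",
"execution_count": null,
"id": "attention-smoke-test",
"metadata": {},
"outputs": [],
"source": [
"# Editor's smoke test (not in the original PR). With q: (batch, q_len, q_dim)\n",
"# and k/v: (batch, kv_len, k_dim / v_dim), the output should come back in the\n",
"# query's shape, here (1, 6, 2).\n",
"q = Tensor.randn(1, 6, 2)\n",
"k = Tensor.randn(1, 7, 3)\n",
"v = Tensor.randn(1, 7, 4)\n",
"print(attentionWithFastKANTransform(q, k, v).shape)"
]
},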
{
"cell_type": "code",
"execution_count": 483,
"id": "6f5fb137-bcff-4982-83b0-c95579a0b5b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar\n",
"(10000, 1, 28, 28) dtypes.uchar (10000,) dtypes.uchar\n"
]
}
],
"source": [
"#Okay, we've ascended the hill!\n",
"# Time to train MNIST\n",
"\n",
"from tinygrad.nn.datasets import mnist\n",
"X_train, Y_train, X_test, Y_test = mnist()\n",
"print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)\n",
"print(X_test.shape, X_test.dtype, Y_test.shape, Y_test.dtype)\n",
"# (60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar"
]
},
{
"cell_type": "code",
"execution_count": 484,
"id": "9f50ecf8-edaa-4a29-bc77-96a115dcf86e",
"metadata": {},
"outputs": [],
"source": [
"model = FastKAN([28 * 28, 64, 10])"
]
},
{
"cell_type": "code",
"execution_count": 485,
"id": "d8c8dbe2-ad60-491f-930b-4e29b72a3e01",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Tensor <LB METAL (10000,) int (<BinaryOps.ADD: 1>, None)> on METAL with grad None>\n",
"(10000,)\n",
"0.13490000367164612\n"
]
}
],
"source": [
"print(model(X_test.view(-1, 28 * 28)).argmax(axis=1))\n",
"print(Y_test.shape)\n",
"\n",
"acc = (model(X_test.view(-1, 28 * 28)).argmax(axis=1) == Y_test).mean()\n",
"# NOTE: tinygrad is lazy, and hasn't actually run anything by this point\n",
"print(acc.item()) # ~10% accuracy, as expected from a random model"
]
},
{
"cell_type": "code",
"execution_count": 486,
"id": "b2c4d91c-7e08-4db2-a844-ed9de2f43304",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<Tensor <LB METAL (784,) float ShapeTracker(views=(View(shape=(784,), strides=(0,), offset=0, mask=None, contiguous=False),))> on METAL with grad None>, <Tensor <LB METAL (784,) float ShapeTracker(views=(View(shape=(784,), strides=(0,), offset=0, mask=None, contiguous=False),))> on METAL with grad None>, <Tensor <LB METAL (8,) float (<MetaOps.COPY: 3>, <buf real:True device:METAL size:8 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64, 6272) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:401408 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64, 784) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:50176 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64,) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:64 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (64,) float ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False),))> on METAL with grad None>, <Tensor <LB METAL (64,) float ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False),))> on METAL with grad None>, <Tensor <LB METAL (8,) float (<MetaOps.COPY: 3>, <buf real:True device:METAL size:8 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (10, 512) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:5120 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (10, 64) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:640 dtype:dtypes.float offset:0>)> on METAL with grad None>, <Tensor <LB METAL (10,) float (<BinaryOps.ADD: 1>, <buf real:True device:METAL size:10 dtype:dtypes.float offset:0>)> on METAL with grad None>]\n"
]
}
],
"source": [
"print(nn.state.get_parameters(model))\n",
"optim = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-3, weight_decay=1e-4)\n",
"batch_size = 128\n",
"def step():\n",
" Tensor.training = True # makes dropout work\n",
" samples = Tensor.randint(batch_size, high=X_train.shape[0])\n",
" X, Y = X_train[samples], Y_train[samples]\n",
" optim.zero_grad()\n",
" loss = (model(X.view(-1, 28 * 28)) + 1e-8).sparse_categorical_crossentropy(Y).backward()\n",
" optim.step()\n",
" return loss"
]
},
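{
"cell_type": "code",
"execution_count": null,
"id": "step-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Editor's addition: one un-jitted step as a smoke test before timing; it\n",
"# should run end-to-end and return a finite scalar loss.\n",
"print(step().item())"
]
},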
{
"cell_type": "code",
"execution_count": 487,
"id": "7a830e30-57f1-4ccf-bac5-a35592cae10c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.3113804580643773,\n",
" 0.02766275010071695,\n",
" 0.02911062492057681,\n",
" 0.02816504193469882,\n",
" 0.025042542023584247]"
]
},
"execution_count": 487,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import timeit\n",
"timeit.repeat(step, repeat=5, number=1)"
]
},
{
"cell_type": "code",
"execution_count": 488,
"id": "aef1c655-2d4c-41ac-9a03-67fac5275b8d",
"metadata": {},
"outputs": [],
"source": [
"from tinygrad import TinyJit\n",
"jit_step = TinyJit(step)"
]
},
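{
"cell_type": "markdown",
"id": "tinyjit-warmup-note",
"metadata": {},
"source": [
"Editor's note: `TinyJit`'s first couple of calls run un-jitted and capture the kernel graph, so the earliest repeats below still include that overhead; the later repeats are closer to steady-state step time."
]
},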
{
"cell_type": "code",
"execution_count": 489,
"id": "d2f917eb-5046-4287-85d4-7a4ab8d70d89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.0612322089727968,\n",
" 0.032529500080272555,\n",
" 0.027406624983996153,\n",
" 0.056881166994571686,\n",
" 0.020832375157624483]"
]
},
"execution_count": 489,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import timeit\n",
"timeit.repeat(jit_step, repeat=5, number=1)"
]
},
{
"cell_type": "code",
"execution_count": 490,
"id": "cd0245b2-de81-414a-a7cf-8b4620f590d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 0, loss 5.56, acc 64.73%\n",
"step 100, loss 0.47, acc 86.68%\n",
"step 200, loss 0.23, acc 89.41%\n",
"step 300, loss 0.22, acc 91.03%\n",
"step 400, loss 0.17, acc 92.23%\n",
"step 500, loss 0.13, acc 93.24%\n",
"step 600, loss 0.20, acc 93.58%\n",
"step 700, loss 0.13, acc 94.19%\n",
"step 800, loss 0.12, acc 93.95%\n",
"step 900, loss 0.17, acc 94.72%\n",
"step 1000, loss 0.11, acc 94.75%\n",
"step 1100, loss 0.05, acc 95.34%\n",
"step 1200, loss 0.11, acc 95.40%\n",
"step 1300, loss 0.05, acc 95.67%\n",
"step 1400, loss 0.09, acc 95.55%\n",
"step 1500, loss 0.13, acc 95.87%\n",
"step 1600, loss 0.10, acc 96.31%\n",
"step 1700, loss 0.08, acc 96.02%\n",
"step 1800, loss 0.04, acc 95.80%\n",
"step 1900, loss 0.05, acc 96.41%\n",
"step 2000, loss 0.05, acc 96.12%\n",
"step 2100, loss 0.09, acc 96.12%\n",
"step 2200, loss 0.07, acc 96.35%\n",
"step 2300, loss 0.11, acc 96.51%\n",
"step 2400, loss 0.05, acc 96.38%\n",
"step 2500, loss 0.12, acc 96.55%\n",
"step 2600, loss 0.03, acc 96.24%\n",
"step 2700, loss 0.04, acc 96.54%\n",
"step 2800, loss 0.04, acc 96.78%\n",
"step 2900, loss 0.02, acc 96.55%\n",
"step 3000, loss 0.07, acc 96.45%\n",
"step 3100, loss 0.07, acc 96.92%\n",
"step 3200, loss 0.02, acc 96.74%\n",
"step 3300, loss 0.06, acc 97.00%\n",
"step 3400, loss 0.03, acc 96.84%\n",
"step 3500, loss 0.05, acc 97.01%\n",
"step 3600, loss 0.03, acc 97.10%\n",
"step 3700, loss 0.02, acc 96.85%\n",
"step 3800, loss 0.02, acc 97.07%\n",
"step 3900, loss 0.01, acc 96.69%\n",
"step 4000, loss 0.02, acc 97.00%\n",
"step 4100, loss 0.05, acc 97.15%\n",
"step 4200, loss 0.03, acc 96.85%\n",
"step 4300, loss 0.03, acc 96.94%\n",
"step 4400, loss 0.02, acc 97.06%\n",
"step 4500, loss 0.02, acc 96.86%\n",
"step 4600, loss 0.04, acc 96.86%\n",
"step 4700, loss 0.06, acc 97.18%\n",
"step 4800, loss 0.01, acc 97.10%\n",
"step 4900, loss 0.05, acc 96.73%\n",
"step 5000, loss 0.01, acc 96.81%\n",
"step 5100, loss 0.04, acc 96.94%\n",
"step 5200, loss 0.03, acc 96.60%\n",
"step 5300, loss 0.03, acc 96.87%\n",
"step 5400, loss 0.02, acc 97.25%\n",
"step 5500, loss 0.04, acc 97.23%\n",
"step 5600, loss 0.05, acc 97.20%\n",
"step 5700, loss 0.02, acc 97.05%\n",
"step 5800, loss 0.02, acc 97.26%\n",
"step 5900, loss 0.01, acc 97.30%\n",
"step 6000, loss 0.06, acc 97.27%\n",
"step 6100, loss 0.05, acc 96.99%\n",
"step 6200, loss 0.04, acc 96.95%\n",
"step 6300, loss 0.00, acc 97.10%\n",
"step 6400, loss 0.00, acc 97.00%\n",
"step 6500, loss 0.05, acc 96.85%\n",
"step 6600, loss 0.03, acc 97.48%\n",
"step 6700, loss 0.07, acc 97.05%\n",
"step 6800, loss 0.03, acc 97.29%\n",
"step 6900, loss 0.02, acc 97.46%\n"
]
}
],
"source": [
"for step in range(7000):\n",
" loss = jit_step()\n",
" if step%100 == 0:\n",
" Tensor.training = False\n",
" acc = (model(X_test.view(-1, 28 * 28)).argmax(axis=1) == Y_test).mean().item()\n",
" print(f\"step {step:4d}, loss {loss.item():.2f}, acc {acc*100.:.2f}%\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}