Skip to content

Commit

Permalink
Add complex dtypes to examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Mar 31, 2024
1 parent d01f928 commit 0bc4280
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 19 deletions.
22 changes: 16 additions & 6 deletions docs/examples/classical_solve.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,45 @@
"execution_count": 1,
"id": "cb3a7781-2358-40c4-82f3-e908bddeb578",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:43:45.094105Z",
"start_time": "2024-03-31T05:43:44.111990Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"solution: [-2.7321298 -8.52878 -7.7226872]\n"
"A=\n",
"[[-1.8459436 -0.2744466j 0.02393756-0.03172905j 0.76815367-1.4444253j ]\n",
" [-1.0467293 +0.05608991j 1.0891742 -0.03264743j 0.7513123 +0.56285536j]\n",
" [ 0.38307396-1.0190808j 0.01203694-1.1971304j 0.19252291-0.26424018j]]\n",
"b=[0.23162952+0.3614433j 0.05800135+1.6094692j 0.8979094 +0.16941352j]\n",
"x=[-0.07652722-0.34397143j -0.22629777+1.0359733j 0.22135164-0.00880566j]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import lineax as lx\n",
"\n",
"\n",
"matrix = jr.normal(jr.PRNGKey(0), (3, 3))\n",
"vector = jr.normal(jr.PRNGKey(1), (3,))\n",
"matrix = jr.normal(jr.PRNGKey(0), (3, 3), dtype=jnp.complex64)\n",
"vector = jr.normal(jr.PRNGKey(1), (3,), dtype=jnp.complex64)\n",
"operator = lx.MatrixLinearOperator(matrix)\n",
"solution = lx.linear_solve(operator, vector)\n",
"print(\"solution:\", solution.value)"
"print(f\"A=\\n{matrix}\\nb={vector}\\nx={solution.value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "py39"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
74 changes: 61 additions & 13 deletions docs/examples/structured_matrices.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
"execution_count": 1,
"id": "8e275652-dd80-4a9a-b3ac-b96dc16d3334",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:20.247996Z",
"start_time": "2024-03-31T05:45:19.640650Z"
},
"tags": []
},
"outputs": [
Expand Down Expand Up @@ -50,6 +54,10 @@
"execution_count": 2,
"id": "ba23ecc4-bdea-4293-a138-ce77bc83082c",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:20.346869Z",
"start_time": "2024-03-31T05:45:20.249222Z"
},
"tags": []
},
"outputs": [],
Expand All @@ -72,6 +80,10 @@
"execution_count": 3,
"id": "6984f62f-75fc-4d6e-ab42-fdade471be5b",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:20.350245Z",
"start_time": "2024-03-31T05:45:20.347651Z"
},
"tags": []
},
"outputs": [
Expand Down Expand Up @@ -101,6 +113,10 @@
"execution_count": 4,
"id": "102ada9a-0533-40cf-9bad-02918fffb6b1",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:20.407881Z",
"start_time": "2024-03-31T05:45:20.350889Z"
},
"tags": []
},
"outputs": [],
Expand All @@ -118,9 +134,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "d8f5bf66-53cd-4e81-a8d7-a19e86307ad3",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:20.754368Z",
"start_time": "2024-03-31T05:45:20.409445Z"
},
"tags": []
},
"outputs": [
Expand All @@ -129,7 +149,14 @@
"evalue": "`Tridiagonal` may only be used for linear solves with tridiagonal matrices",
"output_type": "error",
"traceback": [
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m `Tridiagonal` may only be used for linear solves with tridiagonal matrices\n"
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m not_tridiagonal_matrix \u001b[38;5;241m=\u001b[39m jr\u001b[38;5;241m.\u001b[39mnormal(jr\u001b[38;5;241m.\u001b[39mPRNGKey(\u001b[38;5;241m0\u001b[39m), (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m4\u001b[39m))\n\u001b[1;32m 2\u001b[0m not_tridiagonal_operator \u001b[38;5;241m=\u001b[39m lx\u001b[38;5;241m.\u001b[39mMatrixLinearOperator(not_tridiagonal_matrix)\n\u001b[0;32m----> 3\u001b[0m solution \u001b[38;5;241m=\u001b[39m \u001b[43mlx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear_solve\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnot_tridiagonal_operator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msolver\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTridiagonal\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
" \u001b[0;31m[... skipping hidden 16 frame]\u001b[0m\n",
"File \u001b[0;32m~/PycharmProjects/lineax/lineax/_solve.py:792\u001b[0m, in \u001b[0;36mlinear_solve\u001b[0;34m(operator, vector, solver, options, state, throw)\u001b[0m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Solution(\n\u001b[1;32m 786\u001b[0m value\u001b[38;5;241m=\u001b[39mvector,\n\u001b[1;32m 787\u001b[0m result\u001b[38;5;241m=\u001b[39mRESULTS\u001b[38;5;241m.\u001b[39msuccessful,\n\u001b[1;32m 788\u001b[0m state\u001b[38;5;241m=\u001b[39mstate,\n\u001b[1;32m 789\u001b[0m stats\u001b[38;5;241m=\u001b[39m{},\n\u001b[1;32m 790\u001b[0m )\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m state \u001b[38;5;241m==\u001b[39m sentinel:\n\u001b[0;32m--> 792\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[43msolver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 793\u001b[0m dynamic_state, static_state \u001b[38;5;241m=\u001b[39m eqx\u001b[38;5;241m.\u001b[39mpartition(state, eqx\u001b[38;5;241m.\u001b[39mis_array)\n\u001b[1;32m 794\u001b[0m dynamic_state \u001b[38;5;241m=\u001b[39m lax\u001b[38;5;241m.\u001b[39mstop_gradient(dynamic_state)\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"File \u001b[0;32m~/PycharmProjects/lineax/lineax/_solver/tridiagonal.py:47\u001b[0m, in \u001b[0;36mTridiagonal.init\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 44\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Tridiagonal` may only be used for linear solves with square matrices\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 45\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_tridiagonal(operator):\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 48\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Tridiagonal` may only be used for linear solves with tridiagonal \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatrices\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 50\u001b[0m )\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tridiagonal(operator), pack_structures(operator)\n",
"\u001b[0;31mValueError\u001b[0m: `Tridiagonal` may only be used for linear solves with tridiagonal matrices"
]
}
],
Expand All @@ -153,15 +180,20 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "b5add874-7a2c-4000-84c3-8c94a121a831",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:23.756921Z",
"start_time": "2024-03-31T05:45:23.599534Z"
},
"tags": []
},
"outputs": [],
"source": [
"matrix = jr.normal(jr.PRNGKey(0), (4, 4))\n",
"operator = lx.MatrixLinearOperator(matrix.T @ matrix)"
"matrix = jr.normal(jr.PRNGKey(0), (4, 4), dtype=jnp.complex64)\n",
"operator = lx.MatrixLinearOperator(matrix.T.conj() @ matrix)\n",
"vector = vector.astype(jnp.complex64)"
]
},
{
Expand All @@ -174,9 +206,13 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"id": "78400416-e774-4f74-a530-e368db84af0e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:25.369705Z",
"start_time": "2024-03-31T05:45:25.141555Z"
},
"tags": []
},
"outputs": [
Expand All @@ -203,9 +239,13 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "f6dc2966-1dfa-4a3c-be6a-974926695547",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:27.161215Z",
"start_time": "2024-03-31T05:45:27.094359Z"
},
"tags": []
},
"outputs": [
Expand All @@ -218,7 +258,9 @@
}
],
"source": [
"operator = lx.MatrixLinearOperator(matrix.T @ matrix, lx.positive_semidefinite_tag)\n",
"operator = lx.MatrixLinearOperator(\n",
" matrix.T.conj() @ matrix, lx.positive_semidefinite_tag\n",
")\n",
"solution2 = lx.linear_solve(operator, vector)\n",
"print(default_solver.select_solver(operator))"
]
Expand All @@ -233,18 +275,24 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "fdcde152-9ac1-4532-a174-3fc39d83d289",
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-31T05:45:28.758538Z",
"start_time": "2024-03-31T05:45:28.753906Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1.400575 -0.41042092 0.5313305 0.28422552]\n",
"[ 1.4005749 -0.41042086 0.53133047 0.2842255 ]\n"
"[ 1.1774138-1.3543551j -7.3548455-4.9371862j 13.12787 -3.2803552j\n",
" 2.7882547+6.808086j ]\n",
"[ 1.1774142-1.3543541j -7.3548408-4.9371862j 13.127863 -3.2803519j\n",
" 2.7882524+6.8080816j]\n"
]
}
],
Expand All @@ -256,9 +304,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "py39",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "py39"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down

0 comments on commit 0bc4280

Please sign in to comment.