Skip to content

Commit 56b406d

Browse files
committed
experiments: add quantization experiment (with timestep issues)
1 parent fbafcfa commit 56b406d

File tree

2 files changed

+173
-1
lines changed

2 files changed

+173
-1
lines changed

notebooks/quantization.ipynb

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%load_ext autoreload\n",
10+
"%autoreload 2\n",
11+
"\n",
12+
"import numpy as np\n",
13+
"import matplotlib.pyplot as plt\n",
14+
"import plotly.express as px\n",
15+
"import networkx\n",
16+
"from loguru import logger\n",
17+
"from tqdm.notebook import tqdm\n",
18+
"\n",
19+
"from pim.models.network import Network\n",
20+
"from pim.models.new.stone import StoneExperiment, StoneResults\n",
21+
"from pim.models.new.stone.rate import CXRatePontin, CPU4PontinLayer\n",
22+
"\n",
23+
"from pim.models.stone import analysis\n",
24+
"\n",
25+
"logger.remove()\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"x = np.linspace(0, 1, 100)\n",
35+
"bins = np.linspace(0, 1, 10, endpoint=False)\n",
36+
"print(bins)\n",
37+
"y = (np.digitize(x, bins)-1) / 10\n",
38+
"plt.plot(x, x)\n",
39+
"plt.plot(x, y)"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"def create_quantized_layer(N):\n",
49+
" def closure(*args, **kwargs):\n",
50+
" return QuantizedCPU4PontinLayer(N, *args, **kwargs)\n",
51+
" return closure\n",
52+
"\n",
53+
"class QuantizedCPU4PontinLayer(CPU4PontinLayer):\n",
54+
" def __init__(self, N, *args, **kwargs):\n",
55+
" super().__init__(*args, **kwargs)\n",
56+
" self.N = N\n",
57+
" self.bins = np.linspace(0, 1, self.N, endpoint = False)\n",
58+
" \n",
59+
" def step(self, network: Network, dt: float):\n",
60+
" \"\"\"Memory neurons update.\n",
61+
" cpu4[0-7] store optic flow peaking at left 45 deg\n",
62+
" cpu[8-15] store optic flow peaking at right 45 deg.\"\"\"\n",
63+
" tb1 = network.output(self.TB1)\n",
64+
" tn1 = network.output(self.TN1) * dt\n",
65+
" tn2 = network.output(self.TN2) * dt\n",
66+
"\n",
67+
" mem_update = np.dot(self.W_TN, tn2)\n",
68+
" mem_update -= np.dot(self.W_TB1, tb1)\n",
69+
" mem_update = np.clip(mem_update, 0, 1)\n",
70+
" mem_update *= self.gain\n",
71+
" self.memory += mem_update\n",
72+
" self.memory -= 0.125 * self.gain * dt\n",
73+
" self.memory = np.clip((np.digitize(self.memory, self.bins) - 1) / self.N, 0.0, 1.0)"
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"metadata": {},
80+
"outputs": [],
81+
"source": [
82+
"parameters = {\n",
83+
" \"model\": \"stone\",\n",
84+
" \"T_outbound\": 1500,\n",
85+
" \"T_inbound\": 1000,\n",
86+
" \"time_subdivision\": 1,\n",
87+
" \"noise\": 0.1,\n",
88+
" \"cx\": \"pontin\"\n",
89+
"}\n",
90+
"\n",
91+
"def create_experiment(cpu4):\n",
92+
" cx = CXRatePontin(CPU4LayerClass=cpu4, noise = parameters[\"noise\"])\n",
93+
" cx.setup()\n",
94+
" experiment = StoneExperiment(parameters)\n",
95+
" experiment.cx = cx\n",
96+
" return experiment\n",
97+
"\n",
98+
"def run_experiment(cpu4, N = 0, ts = 1, report = False):\n",
99+
" experiment = create_experiment(cpu4)\n",
100+
" experiment.parameters[\"time_subdivision\"] = ts\n",
101+
" results = experiment.run(\"test\")\n",
102+
" if report:\n",
103+
" results.report()\n",
104+
" return np.linalg.norm(results.closest_position())\n"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"run_experiment(CPU4PontinLayer, N=1, ts=10, report=True)"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"metadata": {},
120+
"outputs": [],
121+
"source": [
122+
"mean_benchmark = np.mean([run_experiment(CPU4PontinLayer) for i in tqdm(range(0, 10))])\n",
123+
"print(f\"Benchmark mean: {mean_benchmark}\")"
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"metadata": {},
130+
"outputs": [],
131+
"source": [
132+
"Ns = range(10, 10000, 10)\n",
133+
"results1 = [run_experiment(create_quantized_layer(N), ts=1) for N in tqdm(Ns)]\n",
134+
"results2 = [run_experiment(create_quantized_layer(N), ts=2) for N in tqdm(Ns)]\n",
135+
"results10 = [run_experiment(create_quantized_layer(N), ts=10) for N in tqdm(Ns)]"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": null,
141+
"metadata": {},
142+
"outputs": [],
143+
"source": [
144+
"px.scatter(x=Ns, y=results1, labels={\"x\": \"resolution\", \"y\": \"smallest distance from nest\"})"
145+
]
146+
}
147+
],
148+
"metadata": {
149+
"kernelspec": {
150+
"display_name": "pim",
151+
"language": "python",
152+
"name": "pim"
153+
},
154+
"language_info": {
155+
"codemirror_mode": {
156+
"name": "ipython",
157+
"version": 3
158+
},
159+
"file_extension": ".py",
160+
"mimetype": "text/x-python",
161+
"name": "python",
162+
"nbconvert_exporter": "python",
163+
"pygments_lexer": "ipython3",
164+
"version": "3.10.4"
165+
}
166+
},
167+
"nbformat": 4,
168+
"nbformat_minor": 4
169+
}

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ pandas>=1.4.2
88
networkx>=2.8.2
99
numpydoc>=1.3.1
1010
asyncio
11-
websockets
11+
websockets
12+
plotly
13+
tqdm
14+
ipywidgets

0 commit comments

Comments
 (0)