|
| 1 | +from pyhcl import * |
| 2 | +from cocotb import * |
| 3 | +from typing import List |
| 4 | +from math import * |
| 5 | +import os |
| 6 | + |
| 7 | +# Approach setting |
| 8 | +PT = 1 |
| 9 | +COCOTB = 2 |
| 10 | +TREADLE = 3 |
| 11 | + |
| 12 | +CLOCK_PERIOD = 2 |
| 13 | + |
| 14 | +# DUT setting |
| 15 | +# configs |
| 16 | +dut = [] |
| 17 | + |
| 18 | +# class Demo: |
| 19 | +# dataBits = ShellKey().memParams.dataBits |
| 20 | +# data = U.w(dataBits) |
| 21 | +from pyhcl.core._meta_pub import Pub |
| 22 | +from pyhcl.core._repr import CType |
| 23 | + |
| 24 | + |
| 25 | +def ispow2(n): |
| 26 | + return (n & (n-1)) == 0 |
| 27 | + |
| 28 | + |
| 29 | +class Counter: |
| 30 | + def __init__(self, n): |
| 31 | + assert n >= 0 |
| 32 | + |
| 33 | + self.n = n |
| 34 | + self.value = RegInit(U.w(int(ceil(log(n, 2))))(0)) if n > 1 else U(0) |
| 35 | + |
| 36 | + def inc(self): |
| 37 | + if self.n > 1: |
| 38 | + wrap = self.value == S(self.n - 1) |
| 39 | + self.value <<= self.value + U(1) |
| 40 | + |
| 41 | + # is n the power of 2? |
| 42 | + if ispow2(self.n) != 0: |
| 43 | + with when(wrap): |
| 44 | + self.value <<= U(0) |
| 45 | + |
| 46 | + return wrap |
| 47 | + else: |
| 48 | + return Bool(True) |
| 49 | + |
| 50 | + |
| 51 | +def queue(gentype, entries): |
| 52 | + class Queue_IO(Bundle_Helper): |
| 53 | + def __init__(self): |
| 54 | + self.count = Output(U.w(int(ceil(log(entries, 2))))) |
| 55 | + self.enq = flipped(decoupled(gentype)) |
| 56 | + self.deq = decoupled(gentype) |
| 57 | + |
| 58 | + cio = Queue_IO() |
| 59 | + |
| 60 | + class Queue(Module): |
| 61 | + io = mapper(Queue_IO()) |
| 62 | + |
| 63 | + # class Demo: |
| 64 | + # pass |
| 65 | + # |
| 66 | + # io = Demo() |
| 67 | + # io.count = tio._count |
| 68 | + # io.enq = Demo() |
| 69 | + # io.enq.__dict__["valid"] = tio.__getattribute__("enq_valid") |
| 70 | + # io.enq.ready = tio.enq_ready |
| 71 | + # io.enq.bits = tio.enq_bits |
| 72 | + # io.deq = Demo() |
| 73 | + # io.deq.valid = tio.deq_valid |
| 74 | + # io.deq.ready = tio.deq_ready |
| 75 | + # io.deq.bits = tio.deq_bits |
| 76 | + # # io = mapper(Queue_IO()) |
| 77 | + |
| 78 | + # Module Logic |
| 79 | + ram = Mem(entries, gentype) |
| 80 | + enq_ptr = Counter(entries) |
| 81 | + deq_ptr = Counter(entries) |
| 82 | + maybe_full = RegInit(Bool(False)) |
| 83 | + |
| 84 | + ptr_match = enq_ptr.value == deq_ptr.value |
| 85 | + empty = ptr_match & (~maybe_full) |
| 86 | + full = ptr_match & maybe_full |
| 87 | + do_enq = io.enq_valid & io.enq_ready |
| 88 | + do_deq = io.deq_valid & io.deq_ready |
| 89 | + |
| 90 | + with when(do_enq): |
| 91 | + ram[enq_ptr.value] <<= io.enq_bits |
| 92 | + enq_ptr.inc() |
| 93 | + |
| 94 | + with when(do_deq): |
| 95 | + deq_ptr.inc() |
| 96 | + |
| 97 | + with when(do_enq != do_deq): |
| 98 | + maybe_full <<= do_enq |
| 99 | + |
| 100 | + io.deq_valid <<= ~empty |
| 101 | + io.enq_ready <<= ~full |
| 102 | + io.deq_bits <<= ram[deq_ptr.value] |
| 103 | + |
| 104 | + ptr_diff = enq_ptr.value - deq_ptr.value |
| 105 | + if ispow2(entries): |
| 106 | + io.count <<= Mux(maybe_full & ptr_match, U(entries), U(0)) | ptr_diff |
| 107 | + else: |
| 108 | + io.count <<= Mux(ptr_match, |
| 109 | + Mux(maybe_full, U(entries), U(0)), |
| 110 | + Mux(deq_ptr.value > enq_ptr.value, |
| 111 | + U(entries) + ptr_diff, ptr_diff)) |
| 112 | + |
| 113 | + # # debug |
| 114 | + # rio.enq_cvalue <<= enq_ptr.value |
| 115 | + # rio.deq_cvalud <<= deq_ptr.value |
| 116 | + |
| 117 | + return Queue() |
| 118 | + |
| 119 | + |
| 120 | +class BaseType: |
| 121 | + pass |
| 122 | + |
| 123 | + |
| 124 | +# Enhance current io functions |
| 125 | +class Bundle_Helper: |
| 126 | + pass |
| 127 | + |
| 128 | + |
| 129 | +def mapper_helper(bundle, dic=None, prefix=""): |
| 130 | + tdic = {} if dic is None else dic |
| 131 | + |
| 132 | + for k in bundle.__dict__: |
| 133 | + v = bundle.__dict__[k] |
| 134 | + if isinstance(v, Pub): |
| 135 | + if prefix == "": |
| 136 | + tdic[k] = v |
| 137 | + else: |
| 138 | + tdic[prefix+"_"+k] = v |
| 139 | + elif isinstance(v, Bundle_Helper): |
| 140 | + if prefix == "": |
| 141 | + mapper_helper(v, tdic, k) |
| 142 | + else: |
| 143 | + mapper_helper(v, tdic, prefix+"_"+k) |
| 144 | + elif isinstance(v, List): |
| 145 | + for i in range(len(v)): |
| 146 | + if isinstance(v[i], Pub): |
| 147 | + if prefix == "": |
| 148 | + tdic[k+"_"+str(i)] = v[i] |
| 149 | + else: |
| 150 | + tdic[prefix+"_"+k+"_"+str(i)] = v[i] |
| 151 | + elif isinstance(v[i], Bundle_Helper): |
| 152 | + if prefix == "": |
| 153 | + mapper_helper(v[i], tdic, k+"_"+str(i)) |
| 154 | + else: |
| 155 | + mapper_helper(v[i], tdic, prefix+"_"+k+"_"+str(i)) |
| 156 | + |
| 157 | + return tdic |
| 158 | + |
| 159 | + |
| 160 | +def mapper(bundle): |
| 161 | + dct = mapper_helper(bundle) |
| 162 | + io = IO(**dct) |
| 163 | + |
| 164 | + return io |
| 165 | + |
| 166 | + |
| 167 | +def decoupled(basetype): |
| 168 | + coupled = Bundle_Helper() |
| 169 | + coupled.valid = Output(Bool) |
| 170 | + coupled.ready = Input(Bool) |
| 171 | + |
| 172 | + if isinstance(basetype, CType) or isinstance(basetype, type): |
| 173 | + coupled.bits = Output(basetype) |
| 174 | + elif isinstance(basetype, BaseType): |
| 175 | + coupled.bits = Bundle_Helper() |
| 176 | + dic = basetype.__dict__ |
| 177 | + for keys in dic: |
| 178 | + if isinstance(dic[keys], CType) or isinstance(dic[keys], type): |
| 179 | + coupled.bits.__dict__[keys] = Output(dic[keys]) |
| 180 | + |
| 181 | + return coupled |
| 182 | + |
| 183 | + |
| 184 | +def valid(basetype): |
| 185 | + coupled = Bundle_Helper() |
| 186 | + coupled.valid = Output(Bool) |
| 187 | + |
| 188 | + if isinstance(basetype, CType) or isinstance(basetype, type): |
| 189 | + coupled.bits = Output(basetype) |
| 190 | + elif isinstance(basetype, Vec): |
| 191 | + coupled.bits = Output(basetype) |
| 192 | + elif isinstance(basetype, BaseType): |
| 193 | + coupled.bits = Bundle_Helper() |
| 194 | + dic = basetype.__dict__ |
| 195 | + for keys in dic: |
| 196 | + if isinstance(dic[keys], CType) or isinstance(dic[keys], type): |
| 197 | + coupled.bits.__dict__[keys] = Output(dic[keys]) |
| 198 | + |
| 199 | + return coupled |
| 200 | + |
| 201 | + |
| 202 | +def base_flipped(obj): |
| 203 | + return Output(obj.typ) if isinstance(obj.value, Input) else Input(obj.typ) |
| 204 | + |
| 205 | + |
| 206 | +def flipped(bundle): |
| 207 | + dic = bundle.__dict__ |
| 208 | + for keys in dic: |
| 209 | + if isinstance(dic[keys], Pub): |
| 210 | + dic[keys] = base_flipped(dic[keys]) |
| 211 | + elif isinstance(dic[keys], Bundle_Helper): |
| 212 | + flipped(dic[keys]) |
| 213 | + |
| 214 | + return bundle |
| 215 | + |
| 216 | + |
| 217 | +def log2ceil(v): |
| 218 | + return ceil(log(v, 2)) |
| 219 | + |
| 220 | + |
| 221 | +def Mem_maskwrite(m, data, mask, length): |
| 222 | + for i in range(length): |
| 223 | + with when(~(mask[i])): |
| 224 | + data[i] <<= U(0) |
| 225 | + m <<= data |
| 226 | + |
| 227 | + |
| 228 | +@cocotb.coroutine |
| 229 | +def cocotb_poke(dut, signals, value): |
| 230 | + dut.signals = value |
| 231 | + yield Timer(CLOCK_PERIOD, 'ns') |
| 232 | + |
| 233 | + |
| 234 | +@cocotb.coroutine |
| 235 | +def cocotb_peek(dut, signals): |
| 236 | + return dut.signals |
| 237 | + |
| 238 | + |
| 239 | +# General manipulate functions for cocotb, PT, and Treadle |
| 240 | +def poke(signals, value, ap): |
| 241 | + if ap == PT: |
| 242 | + return simulator.poke(signals, value) |
| 243 | + elif ap == COCOTB: |
| 244 | + return cocotb_poke(dut, signals, value) |
| 245 | + else: |
| 246 | + # switch to treadle insert assertions |
| 247 | + os.system("treadle.sh") |
| 248 | + |
| 249 | + |
| 250 | +def peek(signals, ap): |
| 251 | + if ap == PT: |
| 252 | + return simulator.peek(signals) |
| 253 | + elif ap == COCOTB: |
| 254 | + return cocotb_peek(dut, signals) |
| 255 | + else: |
| 256 | + # switch to treadle insert assertions |
| 257 | + os.system("treadle.sh") |
| 258 | + |
| 259 | + |
| 260 | +lass AXIParams: |
| 261 | + def __init__(self, |
| 262 | + coherent: bool = False, |
| 263 | + idBits: int = 1, |
| 264 | + addrBits: int = 32, |
| 265 | + dataBits: int = 64, |
| 266 | + lenBits: int = 8, |
| 267 | + userBits: int = 1 |
| 268 | + ): |
| 269 | + assert addrBits > 0 |
| 270 | + assert dataBits >= 8 and dataBits % 2 == 0 |
| 271 | + |
| 272 | + self.coherent = coherent |
| 273 | + self.idBits = idBits |
| 274 | + self.addrBits = addrBits |
| 275 | + self.dataBits = dataBits |
| 276 | + self.lenBits = lenBits # Max burst length = 256, lenBits = 8 |
| 277 | + self.userBits = userBits |
| 278 | + |
| 279 | + self.strbBits: int = int(dataBits / 8) |
| 280 | + self.sizeBits: int = 3 |
| 281 | + self.burstBits: int = 2 |
| 282 | + self.lockBits: int = 2 |
| 283 | + self.cacheBits: int = 4 |
| 284 | + self.protBits: int = 3 |
| 285 | + self.qosBits: int = 4 |
| 286 | + self.regionBits: int = 4 |
| 287 | + self.respBits: int = 2 |
| 288 | + self.sizeConst: int = int(ceil(log(int(dataBits / 8), 2))) |
| 289 | + self.idConst: int = 0 |
| 290 | + self.userConst: int = 1 if coherent else 0 |
| 291 | + self.burstConst: int = 1 |
| 292 | + self.lockConst: int = 0 |
| 293 | + self.cacheConst: int = 15 if coherent else 3 |
| 294 | + self.protConst: int = 4 if coherent else 0 |
| 295 | + self.qosConst: int = 0 |
| 296 | + self.regionConst: int = 0 |
| 297 | + |
| 298 | +class VMEParams: |
| 299 | + ''' |
| 300 | + VME parameters. |
| 301 | + These parameters are used on VME interfaces and modules. |
| 302 | + ''' |
| 303 | + nReadClients: int = 5 |
| 304 | + nWriteClients: int = 1 |
| 305 | + |
| 306 | + |
| 307 | +class VCRParams: |
| 308 | + nCtrl = 1 |
| 309 | + nECnt = 1 |
| 310 | + nVals = 1 |
| 311 | + nPtrs = 6 |
| 312 | + regBits = 32 |
| 313 | + |
| 314 | + |
| 315 | +# Shell parameters |
| 316 | +class ShellParams: |
| 317 | + hostParams = AXIParams() |
| 318 | + memParams = AXIParams() |
| 319 | + vcrParams = VCRParams() |
| 320 | + vmeParams = VMEParams() |
| 321 | + |
| 322 | + |
| 323 | +class ShellKey(ShellParams): |
| 324 | + pass |
| 325 | + |
| 326 | + |
| 327 | +class CoreParams: |
| 328 | + batch: int = 1 |
| 329 | + blockOut: int = 16 |
| 330 | + blockIn: int = 16 |
| 331 | + inpBits: int = 8 |
| 332 | + wgtBits: int = 8 |
| 333 | + uopBits: int = 32 |
| 334 | + accBits: int = 32 |
| 335 | + outBits: int = 8 |
| 336 | + uopMemDepth: int = 512 |
| 337 | + inpMemDepth: int = 512 |
| 338 | + wgtMemDepth: int = 512 |
| 339 | + accMemDepth: int = 512 |
| 340 | + outMemDepth: int = 512 |
| 341 | + instQueueEntries: int = 32 |
| 342 | + |
| 343 | + |
| 344 | +class CoreKey(CoreParams): |
| 345 | + pass |
0 commit comments