// Copyright (c) 2023, DeepLink.

- #include <cstddef>
#include <functional>
#include <memory>
#include <stack>
#include <thread>
#include <utility>
#include <vector>

+ #include "csrc_dipu/utils/env.hpp"
+
#include "DIPUCachingAllocator.h"
#include "DIPUSpinMutex.h"

namespace dipu {

- inline size_t round_up_to_alignment(size_t nbytes, size_t alignment_size) {
-   return ((nbytes - 1) | (alignment_size - 1)) + 1;
- }
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+ const size_t kMaxExtendSize = get_env_or_default("DIPU_MAX_EXTEND_SIZE", 1024)
+                               << 20U;

class BFCachingAllocatorImpl {
 public:
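kMaxExtendSize caps how far a single device allocation may grow; the environment value is read once, interpreted in MiB, and shifted left by 20 bits into bytes. A minimal sketch of the equivalent lookup using only the standard library (the real helper, get_env_or_default from csrc_dipu/utils/env.hpp, is assumed to behave roughly like this):

#include <cstddef>
#include <cstdlib>
#include <string>

// Read DIPU_MAX_EXTEND_SIZE (in MiB) from the environment, defaulting to
// 1024 MiB, then convert MiB to bytes by shifting left by 20 bits.
std::size_t max_extend_size_bytes() {
  std::size_t mib = 1024;
  if (const char* env = std::getenv("DIPU_MAX_EXTEND_SIZE")) {
    mib = std::stoul(env);
  }
  return mib << 20U;  // MiB -> bytes
}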
@@ -30,23 +31,10 @@ class BFCachingAllocatorImpl {
  // Number of second level bins (linearly)
  static constexpr int kNumSubBins = 4;
  static constexpr int kLogNumSubBins = 2;
-
  // Allocation parameters
-   static constexpr size_t kMinBlockSize =
-       512;  // all sizes are rounded to at least 512 bytes
-   static constexpr size_t kSmallSize =
-       1048576;  // largest "small" allocation is 1 MiB
-   static constexpr size_t kSmallBuffer =
-       2097152;  // "small" allocations are packed in 2 MiB blocks
-   static constexpr size_t kLargeBuffer =
-       20971520;  // "large" allocations may be packed in 20 MiB blocks
-   static constexpr size_t kMinLargeAlloc =
-       10485760;  // allocations between 1 and 10 MiB may use kLargeBuffer
-   static constexpr size_t kRoundLarge =
-       2097152;  // round up large allocations to 2 MiB
-   static constexpr size_t kMaxSplitableBlockSize =
-       200 << 20;  // To further reduce fragmentation, blocks >= 200MB are not
-                   // allowed to be split
+   static constexpr size_t kMinAllocationSize = 512;
+   static constexpr size_t kMaxInternalFragmentation = 8U << 20U;  // 8MB
+   static constexpr size_t kMinExtendSize = 8U << 20U;             // 8MB

  size_t cachedBytes = 0;
  size_t allocatedBytes = 0;
@@ -79,6 +67,8 @@ class BFCachingAllocatorImpl {
    __uint128_t bits = 0;
    // Virtual chunks which are the heads of the bins
    std::array<int, static_cast<size_t>(kNumBigBins * kNumSubBins)> binHeads_{};
+     // The extending size next time
+     size_t currExtendSize_ = kMinExtendSize;

    explicit StreamSet(size_t id) : id(id) {}

@@ -150,11 +140,7 @@ class BFCachingAllocatorImpl {
  mutable mutex_t mut_;

  static size_t roundBytes(size_t nbytes) {
-     if (nbytes <= kMinBlockSize) {
-       return kMinBlockSize;
-     }
-     int clz = __builtin_clzll(nbytes - 1);
-     return (1LU << (sizeof(int64_t) - clz));
+     return ((nbytes - 1) | (kMinAllocationSize - 1)) + 1;
  }

  int newChunk(void* ptr, size_t size, size_t stream) {
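roundBytes() now rounds a request up to the next multiple of kMinAllocationSize (512 bytes) instead of the next power of two, so rounding waste is bounded by 511 bytes. A self-contained sketch of the bit trick, which relies on the alignment being a power of two:

#include <cstddef>

constexpr std::size_t kMinAllocationSize = 512;  // must be a power of two

// ((n - 1) | (align - 1)) + 1 rounds n up to the next multiple of align.
constexpr std::size_t roundBytes(std::size_t nbytes) {
  return ((nbytes - 1) | (kMinAllocationSize - 1)) + 1;
}

static_assert(roundBytes(1) == 512, "small requests round up to 512");
static_assert(roundBytes(512) == 512, "exact multiples are unchanged");
static_assert(roundBytes(513) == 1024, "others round up to the next multiple");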
@@ -177,7 +163,7 @@ class BFCachingAllocatorImpl {
    // Big bin range:
    //   [2^`bigBinIdx`, 2^(`bigBinIdx`+1)), length: 2^`bigBinIdx`
    // Split big bin into `kNumSubBins` sub bins
-     size_t nBlocks = nbytes / kMinBlockSize;
+     size_t nBlocks = nbytes / kMinAllocationSize;
    constexpr int kMaxBinIdx = 63;
    int bigBinIdx = kMaxBinIdx - __builtin_clzll(nBlocks);
    // If `nbytes` is so large, we just put it into the last
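Only the divisor is renamed in this hunk (both constants are 512 bytes); the big-bin index is still the floor of log2 of the 512-byte block count, computed with count-leading-zeros. A small illustration, assuming the request was already rounded up so nBlocks >= 1 (for nBlocks == 0, __builtin_clzll is undefined):

#include <cstddef>

// floor(log2(nbytes / 512)) via count-leading-zeros on a 64-bit value.
int bigBinIndex(std::size_t nbytes) {
  constexpr std::size_t kMinAllocationSize = 512;
  constexpr int kMaxBinIdx = 63;
  std::size_t nBlocks = nbytes / kMinAllocationSize;  // 4096 bytes -> 8 blocks
  return kMaxBinIdx - __builtin_clzll(nBlocks);       // 8 blocks -> index 3
}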
@@ -253,22 +239,16 @@ class BFCachingAllocatorImpl {
    return id;
  }

-   void shrink(StreamSetHandle& set, size_t try_release_size = 0) {
-     size_t released_size = 0;
+   void shrink(StreamSetHandle& set) {
    for (int binHead : set->binHeads_) {
      int k = chunks_[binHead].nextChunkInList;
      while (k) {
-         auto& chunk_k = chunks_[k];
-         if (chunk_k.isMonoBlock()) {
-           released_size += chunk_k.size;
-           releaseOnDevice(chunk_k.ptr, chunk_k.size);
+         if (chunks_[k].isMonoBlock()) {
+           releaseOnDevice(chunks_[k].ptr, chunks_[k].size);
          removeChunkFromBin(k);
          recycleIds_.push(k);
-           if (try_release_size > 0 && released_size >= try_release_size) {
-             break;
-           }
        }
-         k = chunk_k.nextChunkInList;
+         k = chunks_[k].nextChunkInList;
      }
    }
  }
@@ -311,39 +291,32 @@ class BFCachingAllocatorImpl {
    return id;
  }

-   size_t getAllocateSize(size_t nbytes) {
-     if (nbytes <= kSmallSize) {
-       return kSmallBuffer;
-     }
-     if (nbytes < kMinLargeAlloc) {
-       return kLargeBuffer;
-     }
-     return round_up_to_alignment(nbytes, kRoundLarge);
-   }
-
  int extend(size_t nbytes, StreamSetHandle& set) {
-     size_t allocateSize = getAllocateSize(nbytes);
-
-     void* ptr = allocateOnDevice(allocateSize);
-     if (!ptr) {
-       shrink(set, allocateSize);
-       ptr = allocateOnDevice(allocateSize);
-     }
-     if (!ptr) {
-       shrink(set);
-       ptr = allocateOnDevice(allocateSize);
-     }
-     if (!ptr) {
-       if (allocateSize > nbytes) {
-         allocateSize = nbytes;
-         ptr = allocateOnDevice(allocateSize);
+     emptyCacheWithoutLock();
+     auto& extSize = set->currExtendSize_;
+     bool increased = false;
+     while (extSize < nbytes && extSize < kMaxExtendSize) {
+       extSize *= 2;
+       increased = true;
+     }
+
+     size_t currBytes = std::max(nbytes, extSize);
+     void* ptr = allocateOnDevice(currBytes);
+     if (ptr) {
+       if (!increased && extSize < kMaxExtendSize) {
+         extSize *= 2;
+       }
+     } else {
+       if (currBytes > nbytes) {
+         currBytes = nbytes;
+         ptr = allocateOnDevice(currBytes);
      }
    }
    if (!ptr) {
      return 0;
    }

-     int id = newChunk(ptr, allocateSize, set->id);
+     int id = newChunk(ptr, currBytes, set->id);
    return id;
  }

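The rewritten extend() drops the fixed small/large buffer sizes and instead grows each stream's device request geometrically: currExtendSize_ starts at 8 MiB, doubles until it covers the request (capped at kMaxExtendSize), and doubles once more when an allocation succeeds without having needed any growth. A minimal sketch of just that sizing policy (device allocation is omitted, and the post-success doubling is folded in unconditionally for brevity):

#include <algorithm>
#include <cstddef>

constexpr std::size_t kMinExtendSize = 8U << 20U;      // 8 MiB
constexpr std::size_t kMaxExtendSize = 1024UL << 20U;  // e.g. the 1 GiB default

// How many bytes to request from the device for an allocation of `nbytes`,
// given the stream's current extension size `extSize` (updated in place).
std::size_t nextExtendBytes(std::size_t nbytes, std::size_t& extSize) {
  bool increased = false;
  while (extSize < nbytes && extSize < kMaxExtendSize) {
    extSize *= 2;  // grow until the request fits or the cap is reached
    increased = true;
  }
  std::size_t currBytes = std::max(nbytes, extSize);
  if (!increased && extSize < kMaxExtendSize) {
    extSize *= 2;  // pre-grow for the next call when no growth was needed
  }
  return currBytes;
}

// Usage: std::size_t ext = kMinExtendSize;
//        nextExtendBytes(1U << 20U, ext);  // returns 8 MiB, ext becomes 16 MiB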
@@ -398,7 +371,8 @@ class BFCachingAllocatorImpl {
    }

    if (id) {
-       if (chunks_[id].size >= (nbytes << 1)) {
+       if (chunks_[id].size >= nbytes * 2 ||
+           chunks_[id].size >= nbytes + kMaxInternalFragmentation) {
        id = split(id, nbytes);
      }
      chunks_[id].allocated = true;
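A matching chunk is now split not only when it is at least twice the requested size but also whenever the leftover would exceed kMaxInternalFragmentation, bounding internal fragmentation on large requests to 8 MiB. A short illustration of the predicate:

#include <cstddef>

constexpr std::size_t kMaxInternalFragmentation = 8U << 20U;  // 8 MiB

bool shouldSplit(std::size_t chunkSize, std::size_t nbytes) {
  return chunkSize >= nbytes * 2 ||
         chunkSize >= nbytes + kMaxInternalFragmentation;
}

// Example: a 100 MiB chunk serving a 96 MiB request is kept whole (4 MiB of
// slack), while a 110 MiB chunk serving the same request is split because its
// 14 MiB of slack exceeds the 8 MiB bound.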
@@ -532,9 +506,6 @@ class BFCachingAllocator : public CacheAllocator {
        : DataPtrContextBase(allocator, ptr, size), id_(id), nbytes_(nbytes) {}

    ~Context() {
-       if (size() <= 0) {
-         return;
-       }
      auto allocator_ = static_cast<const BFCachingAllocator*>(allocator());
      DIPU_DEBUG_ALLOCATOR(8, "BFCachingAllocator: add to async_mem_pool:"
                                  << ptr() << ", " << size() << " nbytes, id:"
@@ -544,22 +515,18 @@
        if (ptr()) {
          allocator_->metrics_producer.deallocate(ptr());
          std::deque<DIPUEvent> events;
-           bool record_block = false;
          for (auto const& stream : streams()) {
            events.emplace_back();
            DIPU_DEBUG_ALLOCATOR(8, "BFCachingAllocator: record to stream:"
                                        << stream.rawstream());
            events.back().record(stream);
-             record_block = true;
          }
          allocator_->async_mem_pool()->add(std::make_tuple(ptr(), id_),
                                            events);
          allocator_->set_memory_allocated(allocator_->memory_allocated() -
                                           nbytes_);
-           if (!record_block) {
-             allocator_->restore();
-           }
        }
+         allocator_->restore();
      } else {
        DIPU_DEBUG_ALLOCATOR(8,
                             "BFCachingAllocator:~Context: destory tensor "
@@ -570,12 +537,12 @@

  friend class Context;

-   c10::DataPtr allocate(size_t origin_size) const override {
+   c10::DataPtr allocate(size_t size) const override {
    restore();
    if (async_mem_pool()->size() > kMaxAsyncResourcePoolLength) {
      try_empty_resource_pool();
    }
-     size_t size = getMemoryAlignmentStrategy()->roundBytes(origin_size);
+     size = getMemoryAlignmentStrategy()->roundBytes(size);
    std::tuple<void*, int, size_t> block = impl->allocateRaw(size);
    void* ptr = std::get<0>(block);
    if (ptr == nullptr && size > 0) {
@@ -601,8 +568,8 @@
        deleteBFContext, device());
    DIPU_DEBUG_ALLOCATOR(
        4, "BFCachingAllocator: malloc "
-               << nbytes << ",requires " << origin_size
-               << " nbytes, ptr: " << ptr << ",device:" << device()
+               << nbytes << ",requires " << size << " nbytes, ptr: " << ptr
+               << ",device:" << device()
               << ",async_mempool.size:" << async_mem_pool()->size());
    c10::reportMemoryUsageToProfiler(
        ptr, static_cast<int64_t>(nbytes), memory_allocated(),