From 4057a8a81b1bd14ea2f91380e71307aff68e4a28 Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 18 Jul 2018 10:01:43 -0600 Subject: [PATCH 01/29] swapped ripsFiltration --- src/dionysus/backward.hpp | 2212 +++++++++++++++++ src/dionysus/dionysus/chain.h | 153 ++ src/dionysus/dionysus/chain.hpp | 188 ++ src/dionysus/dionysus/clearing-reduction.h | 45 + src/dionysus/dionysus/clearing-reduction.hpp | 60 + .../dionysus/cohomology-persistence.h | 116 + .../dionysus/cohomology-persistence.hpp | 61 + src/dionysus/dionysus/diagram.h | 105 + src/dionysus/dionysus/distances.h | 93 + src/dionysus/dionysus/distances.hpp | 30 + src/dionysus/dionysus/dlog/progress.h | 57 + src/dionysus/dionysus/fields/q.h | 74 + src/dionysus/dionysus/fields/z2.h | 34 + src/dionysus/dionysus/fields/zp.h | 60 + src/dionysus/dionysus/filtration.h | 123 + .../dionysus/omni-field-persistence.h | 135 + .../dionysus/omni-field-persistence.hpp | 250 ++ src/dionysus/dionysus/ordinary-persistence.h | 64 + src/dionysus/dionysus/pair-recorder.h | 78 + src/dionysus/dionysus/reduced-matrix.h | 166 ++ src/dionysus/dionysus/reduced-matrix.hpp | 78 + src/dionysus/dionysus/reduction.h | 107 + .../dionysus/relative-homology-zigzag.h | 84 + .../dionysus/relative-homology-zigzag.hpp | 122 + src/dionysus/dionysus/rips.h | 147 ++ src/dionysus/dionysus/rips.hpp | 162 ++ src/dionysus/dionysus/row-reduction.h | 54 + src/dionysus/dionysus/row-reduction.hpp | 103 + src/dionysus/dionysus/simplex.h | 272 ++ src/dionysus/dionysus/sparse-row-matrix.h | 184 ++ src/dionysus/dionysus/sparse-row-matrix.hpp | 103 + src/dionysus/dionysus/standard-reduction.h | 44 + src/dionysus/dionysus/standard-reduction.hpp | 47 + src/dionysus/dionysus/trails-chains.h | 17 + src/dionysus/dionysus/zigzag-persistence.h | 141 ++ src/dionysus/dionysus/zigzag-persistence.hpp | 534 ++++ src/tdautils/dionysusUtils.h | 47 +- tests/testthat/test_kde.R | 20 + tests/testthat/test_rips.R | 29 + 39 files changed, 6397 insertions(+), 2 deletions(-) create mode 100755 src/dionysus/backward.hpp create mode 100755 src/dionysus/dionysus/chain.h create mode 100755 src/dionysus/dionysus/chain.hpp create mode 100755 src/dionysus/dionysus/clearing-reduction.h create mode 100755 src/dionysus/dionysus/clearing-reduction.hpp create mode 100755 src/dionysus/dionysus/cohomology-persistence.h create mode 100755 src/dionysus/dionysus/cohomology-persistence.hpp create mode 100755 src/dionysus/dionysus/diagram.h create mode 100755 src/dionysus/dionysus/distances.h create mode 100755 src/dionysus/dionysus/distances.hpp create mode 100755 src/dionysus/dionysus/dlog/progress.h create mode 100755 src/dionysus/dionysus/fields/q.h create mode 100755 src/dionysus/dionysus/fields/z2.h create mode 100755 src/dionysus/dionysus/fields/zp.h create mode 100755 src/dionysus/dionysus/filtration.h create mode 100755 src/dionysus/dionysus/omni-field-persistence.h create mode 100755 src/dionysus/dionysus/omni-field-persistence.hpp create mode 100755 src/dionysus/dionysus/ordinary-persistence.h create mode 100755 src/dionysus/dionysus/pair-recorder.h create mode 100755 src/dionysus/dionysus/reduced-matrix.h create mode 100755 src/dionysus/dionysus/reduced-matrix.hpp create mode 100755 src/dionysus/dionysus/reduction.h create mode 100755 src/dionysus/dionysus/relative-homology-zigzag.h create mode 100755 src/dionysus/dionysus/relative-homology-zigzag.hpp create mode 100755 src/dionysus/dionysus/rips.h create mode 100755 src/dionysus/dionysus/rips.hpp create mode 100755 src/dionysus/dionysus/row-reduction.h create mode 100755 src/dionysus/dionysus/row-reduction.hpp create mode 100755 src/dionysus/dionysus/simplex.h create mode 100755 src/dionysus/dionysus/sparse-row-matrix.h create mode 100755 src/dionysus/dionysus/sparse-row-matrix.hpp create mode 100755 src/dionysus/dionysus/standard-reduction.h create mode 100755 src/dionysus/dionysus/standard-reduction.hpp create mode 100755 src/dionysus/dionysus/trails-chains.h create mode 100755 src/dionysus/dionysus/zigzag-persistence.h create mode 100755 src/dionysus/dionysus/zigzag-persistence.hpp create mode 100644 tests/testthat/test_kde.R create mode 100644 tests/testthat/test_rips.R diff --git a/src/dionysus/backward.hpp b/src/dionysus/backward.hpp new file mode 100755 index 0000000..6b331ba --- /dev/null +++ b/src/dionysus/backward.hpp @@ -0,0 +1,2212 @@ +/* + * backward.hpp + * Copyright 2013 Google Inc. All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef H_6B9572DA_A64B_49E6_B234_051480991C89 +#define H_6B9572DA_A64B_49E6_B234_051480991C89 + +#ifndef __cplusplus +# error "It's not going to compile without a C++ compiler..." +#endif + +#if defined(BACKWARD_CXX11) +#elif defined(BACKWARD_CXX98) +#else +# if __cplusplus >= 201103L +# define BACKWARD_CXX11 +# else +# define BACKWARD_CXX98 +# endif +#endif + +// You can define one of the following (or leave it to the auto-detection): +// +// #define BACKWARD_SYSTEM_LINUX +// - specialization for linux +// +// #define BACKWARD_SYSTEM_UNKNOWN +// - placebo implementation, does nothing. +// +#if defined(BACKWARD_SYSTEM_LINUX) +#elif defined(BACKWARD_SYSTEM_UNKNOWN) +#else +# if defined(__linux) +# define BACKWARD_SYSTEM_LINUX +# else +# define BACKWARD_SYSTEM_UNKNOWN +# endif +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(BACKWARD_SYSTEM_LINUX) + +// On linux, backtrace can back-trace or "walk" the stack using the following +// library: +// +// #define BACKWARD_HAS_UNWIND 1 +// - unwind comes from libgcc, but I saw an equivalent inside clang itself. +// - with unwind, the stacktrace is as accurate as it can possibly be, since +// this is used by the C++ runtine in gcc/clang for stack unwinding on +// exception. +// - normally libgcc is already linked to your program by default. +// +// #define BACKWARD_HAS_BACKTRACE == 1 +// - backtrace seems to be a little bit more portable than libunwind, but on +// linux, it uses unwind anyway, but abstract away a tiny information that is +// sadly really important in order to get perfectly accurate stack traces. +// - backtrace is part of the (e)glib library. +// +// The default is: +// #define BACKWARD_HAS_UNWIND == 1 +// +# if BACKWARD_HAS_UNWIND == 1 +# elif BACKWARD_HAS_BACKTRACE == 1 +# else +# undef BACKWARD_HAS_UNWIND +# define BACKWARD_HAS_UNWIND 1 +# undef BACKWARD_HAS_BACKTRACE +# define BACKWARD_HAS_BACKTRACE 0 +# endif + +// On linux, backward can extract detailed information about a stack trace +// using one of the following library: +// +// #define BACKWARD_HAS_DW 1 +// - libdw gives you the most juicy details out of your stack traces: +// - object filename +// - function name +// - source filename +// - line and column numbers +// - source code snippet (assuming the file is accessible) +// - variables name and values (if not optimized out) +// - You need to link with the lib "dw": +// - apt-get install libdw-dev +// - g++/clang++ -ldw ... +// +// #define BACKWARD_HAS_BFD 1 +// - With libbfd, you get a fair about of details: +// - object filename +// - function name +// - source filename +// - line numbers +// - source code snippet (assuming the file is accessible) +// - You need to link with the lib "bfd": +// - apt-get install binutils-dev +// - g++/clang++ -lbfd ... +// +// #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 +// - backtrace provides minimal details for a stack trace: +// - object filename +// - function name +// - backtrace is part of the (e)glib library. +// +// The default is: +// #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1 +// +# if BACKWARD_HAS_DW == 1 +# elif BACKWARD_HAS_BFD == 1 +# elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 +# else +# undef BACKWARD_HAS_DW +# define BACKWARD_HAS_DW 0 +# undef BACKWARD_HAS_BFD +# define BACKWARD_HAS_BFD 0 +# undef BACKWARD_HAS_BACKTRACE_SYMBOL +# define BACKWARD_HAS_BACKTRACE_SYMBOL 1 +# endif + + +# if BACKWARD_HAS_UNWIND == 1 + +# include +// while gcc's unwind.h defines something like that: +// extern _Unwind_Ptr _Unwind_GetIP (struct _Unwind_Context *); +// extern _Unwind_Ptr _Unwind_GetIPInfo (struct _Unwind_Context *, int *); +// +// clang's unwind.h defines something like this: +// uintptr_t _Unwind_GetIP(struct _Unwind_Context* __context); +// +// Even if the _Unwind_GetIPInfo can be linked to, it is not declared, worse we +// cannot just redeclare it because clang's unwind.h doesn't define _Unwind_Ptr +// anyway. +// +// Luckily we can play on the fact that the guard macros have a different name: +#ifdef __CLANG_UNWIND_H +// In fact, this function still comes from libgcc (on my different linux boxes, +// clang links against libgcc). +# include +extern "C" uintptr_t _Unwind_GetIPInfo(_Unwind_Context*, int*); +#endif + +# endif + +# include +# include +# include +# include +# include +# include +# include + +# if BACKWARD_HAS_BFD == 1 +# include +# ifndef _GNU_SOURCE +# define _GNU_SOURCE +# include +# undef _GNU_SOURCE +# else +# include +# endif +# endif + +# if BACKWARD_HAS_DW == 1 +# include +# include +# include +# endif + +# if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1) + // then we shall rely on backtrace +# include +# endif + +#endif // defined(BACKWARD_SYSTEM_LINUX) + +#if defined(BACKWARD_CXX11) +# include +# include // for std::swap + namespace backward { + namespace details { + template + struct hashtable { + typedef std::unordered_map type; + }; + using std::move; + } // namespace details + } // namespace backward +#elif defined(BACKWARD_CXX98) +# include + namespace backward { + namespace details { + template + struct hashtable { + typedef std::map type; + }; + template + const T& move(const T& v) { return v; } + template + T& move(T& v) { return v; } + } // namespace details + } // namespace backward +#else +# error "Mmm if its not C++11 nor C++98... go play in the toaster." +#endif + +namespace backward { + +namespace system_tag { + struct linux_tag; // seems that I cannot call that "linux" because the name + // is already defined... so I am adding _tag everywhere. + struct unknown_tag; + +#if defined(BACKWARD_SYSTEM_LINUX) + typedef linux_tag current_tag; +#elif defined(BACKWARD_SYSTEM_UNKNOWN) + typedef unknown_tag current_tag; +#else +# error "May I please get my system defines?" +#endif +} // namespace system_tag + + +namespace stacktrace_tag { +#ifdef BACKWARD_SYSTEM_LINUX + struct unwind; + struct backtrace; + +# if BACKWARD_HAS_UNWIND == 1 + typedef unwind current; +# elif BACKWARD_HAS_BACKTRACE == 1 + typedef backtrace current; +# else +# error "I know it's difficult but you need to make a choice!" +# endif +#endif // BACKWARD_SYSTEM_LINUX +} // namespace stacktrace_tag + + +namespace trace_resolver_tag { +#ifdef BACKWARD_SYSTEM_LINUX + struct libdw; + struct libbfd; + struct backtrace_symbol; + +# if BACKWARD_HAS_DW == 1 + typedef libdw current; +# elif BACKWARD_HAS_BFD == 1 + typedef libbfd current; +# elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 + typedef backtrace_symbol current; +# else +# error "You shall not pass, until you know what you want." +# endif +#endif // BACKWARD_SYSTEM_LINUX +} // namespace trace_resolver_tag + +namespace details { + +template + struct rm_ptr { typedef T type; }; + +template + struct rm_ptr { typedef T type; }; + +template + struct rm_ptr { typedef const T type; }; + +template +struct deleter { + template + void operator()(U& ptr) const { + (*F)(ptr); + } +}; + +template +struct default_delete { + void operator()(T& ptr) const { + delete ptr; + } +}; + +template > +class handle { + struct dummy; + T _val; + bool _empty; + +#if defined(BACKWARD_CXX11) + handle(const handle&) = delete; + handle& operator=(const handle&) = delete; +#endif + +public: + ~handle() { + if (not _empty) { + Deleter()(_val); + } + } + + explicit handle(): _val(), _empty(true) {} + explicit handle(T val): _val(val), _empty(false) {} + +#if defined(BACKWARD_CXX11) + handle(handle&& from): _empty(true) { + swap(from); + } + handle& operator=(handle&& from) { + swap(from); return *this; + } +#else + explicit handle(const handle& from): _empty(true) { + // some sort of poor man's move semantic. + swap(const_cast(from)); + } + handle& operator=(const handle& from) { + // some sort of poor man's move semantic. + swap(const_cast(from)); return *this; + } +#endif + + void reset(T new_val) { + handle tmp(new_val); + swap(tmp); + } + operator const dummy*() const { + if (_empty) { + return 0; + } + return reinterpret_cast(_val); + } + T get() { + return _val; + } + T release() { + _empty = true; + return _val; + } + void swap(handle& b) { + using std::swap; + swap(b._val, _val); // can throw, we are safe here. + swap(b._empty, _empty); // should not throw: if you cannot swap two + // bools without throwing... It's a lost cause anyway! + } + + T operator->() { return _val; } + const T operator->() const { return _val; } + + typedef typename rm_ptr::type& ref_t; + ref_t operator*() { return *_val; } + const ref_t operator*() const { return *_val; } + ref_t operator[](size_t idx) { return _val[idx]; } + + // Watch out, we've got a badass over here + T* operator&() { + _empty = false; + return &_val; + } +}; + +} // namespace details + +/*************** A TRACE ***************/ + +struct Trace { + void* addr; + size_t idx; + + Trace(): + addr(0), idx(0) {} + + explicit Trace(void* addr, size_t idx): + addr(addr), idx(idx) {} +}; + +// Really simple, generic, and dumb representation of a variable. +// A variable has a name and can represent either: +// - a value (as a string) +// - a list of values (a list of strings) +// - a map of values (a list of variable) +class Variable { +public: + enum Kind { VALUE, LIST, MAP }; + + typedef std::vector list_t; + typedef std::vector map_t; + + std::string name; + Kind kind; + + Variable(Kind k): kind(k) { + switch (k) { + case VALUE: + new (&storage) std::string(); + break; + + case LIST: + new (&storage) list_t(); + break; + + case MAP: + new (&storage) map_t(); + break; + } + } + + std::string& value() { + return reinterpret_cast(storage); + } + list_t& list() { + return reinterpret_cast(storage); + } + map_t& map() { + return reinterpret_cast(storage); + } + + + const std::string& value() const { + return reinterpret_cast(storage); + } + const list_t& list() const { + return reinterpret_cast(storage); + } + const map_t& map() const { + return reinterpret_cast(storage); + } + +private: + // the C++98 style union for non-trivial objects, yes yes I know, its not + // aligned as good as it can be, blabla... Screw this. + union { + char s1[sizeof (std::string)]; + char s2[sizeof (list_t)]; + char s3[sizeof (map_t)]; + } storage; +}; + +struct TraceWithLocals: public Trace { + // Locals variable and values. + std::vector locals; + + TraceWithLocals(): Trace() {} + TraceWithLocals(const Trace& mini_trace): + Trace(mini_trace) {} +}; + +struct ResolvedTrace: public TraceWithLocals { + + struct SourceLoc { + std::string function; + std::string filename; + unsigned line; + unsigned col; + + SourceLoc(): line(0), col(0) {} + + bool operator==(const SourceLoc& b) const { + return function == b.function + and filename == b.filename + and line == b.line + and col == b.col; + } + + bool operator!=(const SourceLoc& b) const { + return not (*this == b); + } + }; + + // In which binary object this trace is located. + std::string object_filename; + + // The function in the object that contain the trace. This is not the same + // as source.function which can be an function inlined in object_function. + std::string object_function; + + // The source location of this trace. It is possible for filename to be + // empty and for line/col to be invalid (value 0) if this information + // couldn't be deduced, for example if there is no debug information in the + // binary object. + SourceLoc source; + + // An optionals list of "inliners". All the successive sources location + // from where the source location of the trace (the attribute right above) + // is inlined. It is especially useful when you compiled with optimization. + typedef std::vector source_locs_t; + source_locs_t inliners; + + ResolvedTrace(const Trace& mini_trace): + TraceWithLocals(mini_trace) {} + ResolvedTrace(const TraceWithLocals& mini_trace_with_locals): + TraceWithLocals(mini_trace_with_locals) {} +}; + +/*************** STACK TRACE ***************/ + +// default implemention. +template +class StackTraceImpl { +public: + size_t size() const { return 0; } + Trace operator[](size_t) { return Trace(); } + size_t load_here(size_t=0) { return 0; } + size_t load_from(void*, size_t=0) { return 0; } + unsigned thread_id() const { return 0; } +}; + +#ifdef BACKWARD_SYSTEM_LINUX + +class StackTraceLinuxImplBase { +public: + StackTraceLinuxImplBase(): _thread_id(0), _skip(0) {} + + unsigned thread_id() const { + return _thread_id; + } + +protected: + void load_thread_info() { + _thread_id = syscall(SYS_gettid); + if (_thread_id == (size_t) getpid()) { + // If the thread is the main one, let's hide that. + // I like to keep little secret sometimes. + _thread_id = 0; + } + } + + void skip_n_firsts(size_t n) { _skip = n; } + size_t skip_n_firsts() const { return _skip; } + +private: + size_t _thread_id; + size_t _skip; +}; + +class StackTraceLinuxImplHolder: public StackTraceLinuxImplBase { +public: + size_t size() const { + return _stacktrace.size() ? _stacktrace.size() - skip_n_firsts() : 0; + } + Trace operator[](size_t idx) { + if (idx >= size()) { + return Trace(); + } + return Trace(_stacktrace[idx + skip_n_firsts()], idx); + } + void** begin() { + if (size()) { + return &_stacktrace[skip_n_firsts()]; + } + return 0; + } + +protected: + std::vector _stacktrace; +}; + + +#if BACKWARD_HAS_UNWIND == 1 + +namespace details { + +template +class Unwinder { +public: + size_t operator()(F& f, size_t depth) { + _f = &f; + _index = -1; + _depth = depth; + _Unwind_Backtrace(&this->backtrace_trampoline, this); + return _index; + } + +private: + F* _f; + ssize_t _index; + size_t _depth; + + static _Unwind_Reason_Code backtrace_trampoline( + _Unwind_Context* ctx, void *self) { + return ((Unwinder*)self)->backtrace(ctx); + } + + _Unwind_Reason_Code backtrace(_Unwind_Context* ctx) { + if (_index >= 0 and static_cast(_index) >= _depth) + return _URC_END_OF_STACK; + + int ip_before_instruction = 0; + uintptr_t ip = _Unwind_GetIPInfo(ctx, &ip_before_instruction); + + if (not ip_before_instruction) { + ip -= 1; + } + + if (_index >= 0) { // ignore first frame. + (*_f)(_index, (void*)ip); + } + _index += 1; + return _URC_NO_REASON; + } +}; + +template +size_t unwind(F f, size_t depth) { + Unwinder unwinder; + return unwinder(f, depth); +} + +} // namespace details + + +template <> +class StackTraceImpl: public StackTraceLinuxImplHolder { +public: + __attribute__ ((noinline)) // TODO use some macro + size_t load_here(size_t depth=32) { + load_thread_info(); + if (depth == 0) { + return 0; + } + _stacktrace.resize(depth); + size_t trace_cnt = details::unwind(callback(*this), depth); + _stacktrace.resize(trace_cnt); + skip_n_firsts(0); + return size(); + } + size_t load_from(void* addr, size_t depth=32) { + load_here(depth + 8); + + for (size_t i = 0; i < _stacktrace.size(); ++i) { + if (_stacktrace[i] == addr) { + skip_n_firsts(i); + break; + } + } + + _stacktrace.resize(std::min(_stacktrace.size(), + skip_n_firsts() + depth)); + return size(); + } + +private: + struct callback { + StackTraceImpl& self; + callback(StackTraceImpl& self): self(self) {} + + void operator()(size_t idx, void* addr) { + self._stacktrace[idx] = addr; + } + }; +}; + + +#else // BACKWARD_HAS_UNWIND == 0 + +template <> +class StackTraceImpl: public StackTraceLinuxImplHolder { +public: + __attribute__ ((noinline)) // TODO use some macro + size_t load_here(size_t depth=32) { + load_thread_info(); + if (depth == 0) { + return 0; + } + _stacktrace.resize(depth + 1); + size_t trace_cnt = backtrace(&_stacktrace[0], _stacktrace.size()); + _stacktrace.resize(trace_cnt); + skip_n_firsts(1); + return size(); + } + + size_t load_from(void* addr, size_t depth=32) { + load_here(depth + 8); + + for (size_t i = 0; i < _stacktrace.size(); ++i) { + if (_stacktrace[i] == addr) { + skip_n_firsts(i); + _stacktrace[i] = (void*)( (uintptr_t)_stacktrace[i] + 1); + break; + } + } + + _stacktrace.resize(std::min(_stacktrace.size(), + skip_n_firsts() + depth)); + return size(); + } +}; + +#endif // BACKWARD_HAS_UNWIND +#endif // BACKWARD_SYSTEM_LINUX + +class StackTrace: + public StackTraceImpl {}; + +/*********** STACKTRACE WITH LOCALS ***********/ + +// default implemention. +template +class StackTraceWithLocalsImpl: + public StackTrace {}; + +#ifdef BACKWARD_SYSTEM_LINUX +#if BACKWARD_HAS_UNWIND +#if BACKWARD_HAS_DW + +template <> +class StackTraceWithLocalsImpl: + public StackTraceLinuxImplBase { +public: + __attribute__ ((noinline)) // TODO use some macro + size_t load_here(size_t depth=32) { + load_thread_info(); + if (depth == 0) { + return 0; + } + _stacktrace.resize(depth); + size_t trace_cnt = details::unwind(callback(*this), depth); + _stacktrace.resize(trace_cnt); + skip_n_firsts(0); + return size(); + } + size_t load_from(void* addr, size_t depth=32) { + load_here(depth + 8); + + for (size_t i = 0; i < _stacktrace.size(); ++i) { + if (_stacktrace[i].addr == addr) { + skip_n_firsts(i); + break; + } + } + _stacktrace.resize(std::min(_stacktrace.size(), + skip_n_firsts() + depth)); + return size(); + } + size_t size() const { + return _stacktrace.size() ? _stacktrace.size() - skip_n_firsts() : 0; + } + const TraceWithLocals& operator[](size_t idx) { + if (idx >= size()) { + return _nil_trace; + } + return _stacktrace[idx + skip_n_firsts()]; + } + +private: + std::vector _stacktrace; + TraceWithLocals _nil_trace; + + void resolve_trace(TraceWithLocals& trace) { + Variable v(Variable::VALUE); + v.name = "var"; + v.value() = "42"; + trace.locals.push_back(v); + } + + struct callback { + StackTraceWithLocalsImpl& self; + callback(StackTraceWithLocalsImpl& self): self(self) {} + + void operator()(size_t idx, void* addr) { + self._stacktrace[idx].addr = addr; + self.resolve_trace(self._stacktrace[idx]); + } + }; +}; + +#endif // BACKWARD_HAS_DW +#endif // BACKWARD_HAS_UNWIND +#endif // BACKWARD_SYSTEM_LINUX + +class StackTraceWithLocals: + public StackTraceWithLocalsImpl {}; + +/*************** TRACE RESOLVER ***************/ + +template +class TraceResolverImpl; + +#ifdef BACKWARD_SYSTEM_UNKNOWN + +template <> +class TraceResolverImpl { +public: + template + void load_stacktrace(ST&) {} + ResolvedTrace resolve(ResolvedTrace t) { + return t; + } +}; + +#endif + +#ifdef BACKWARD_SYSTEM_LINUX + +class TraceResolverLinuxImplBase { +protected: + std::string demangle(const char* funcname) { + using namespace details; + _demangle_buffer.reset( + abi::__cxa_demangle(funcname, _demangle_buffer.release(), + &_demangle_buffer_length, 0) + ); + if (_demangle_buffer) { + return _demangle_buffer.get(); + } + return funcname; + } + +private: + details::handle _demangle_buffer; + size_t _demangle_buffer_length; +}; + +template +class TraceResolverLinuxImpl; + +#if BACKWARD_HAS_BACKTRACE_SYMBOL == 1 + +template <> +class TraceResolverLinuxImpl: + public TraceResolverLinuxImplBase { +public: + template + void load_stacktrace(ST& st) { + using namespace details; + if (st.size() == 0) { + return; + } + _symbols.reset( + backtrace_symbols(st.begin(), st.size()) + ); + } + + ResolvedTrace resolve(ResolvedTrace trace) { + char* filename = _symbols[trace.idx]; + char* funcname = filename; + while (*funcname && *funcname != '(') { + funcname += 1; + } + trace.object_filename.assign(filename, funcname++); + char* funcname_end = funcname; + while (*funcname_end && *funcname_end != ')' && *funcname_end != '+') { + funcname_end += 1; + } + *funcname_end = '\0'; + trace.object_function = this->demangle(funcname); + trace.source.function = trace.object_function; // we cannot do better. + return trace; + } + +private: + details::handle _symbols; +}; + +#endif // BACKWARD_HAS_BACKTRACE_SYMBOL == 1 + +#if BACKWARD_HAS_BFD == 1 + +template <> +class TraceResolverLinuxImpl: + public TraceResolverLinuxImplBase { +public: + TraceResolverLinuxImpl(): _bfd_loaded(false) {} + + template + void load_stacktrace(ST&) {} + + ResolvedTrace resolve(ResolvedTrace trace) { + Dl_info symbol_info; + + // trace.addr is a virtual address in memory pointing to some code. + // Let's try to find from which loaded object it comes from. + // The loaded object can be yourself btw. + if (not dladdr(trace.addr, &symbol_info)) { + return trace; // dat broken trace... + } + + // Now we get in symbol_info: + // .dli_fname: + // pathname of the shared object that contains the address. + // .dli_fbase: + // where the object is loaded in memory. + // .dli_sname: + // the name of the nearest symbol to trace.addr, we expect a + // function name. + // .dli_saddr: + // the exact address corresponding to .dli_sname. + + if (symbol_info.dli_sname) { + trace.object_function = demangle(symbol_info.dli_sname); + } + + if (not symbol_info.dli_fname) { + return trace; + } + + trace.object_filename = symbol_info.dli_fname; + bfd_fileobject& fobj = load_object_with_bfd(symbol_info.dli_fname); + if (not fobj.handle) { + return trace; // sad, we couldn't load the object :( + } + + + find_sym_result* details_selected; // to be filled. + + // trace.addr is the next instruction to be executed after returning + // from the nested stack frame. In C++ this usually relate to the next + // statement right after the function call that leaded to a new stack + // frame. This is not usually what you want to see when printing out a + // stacktrace... + find_sym_result details_call_site = find_symbol_details(fobj, + trace.addr, symbol_info.dli_fbase); + details_selected = &details_call_site; + +#if BACKWARD_HAS_UNWIND == 0 + // ...this is why we also try to resolve the symbol that is right + // before the return address. If we are lucky enough, we will get the + // line of the function that was called. But if the code is optimized, + // we might get something absolutely not related since the compiler + // can reschedule the return address with inline functions and + // tail-call optimisation (among other things that I don't even know + // or cannot even dream about with my tiny limited brain). + find_sym_result details_adjusted_call_site = find_symbol_details(fobj, + (void*) (uintptr_t(trace.addr) - 1), + symbol_info.dli_fbase); + + // In debug mode, we should always get the right thing(TM). + if (details_call_site.found and details_adjusted_call_site.found) { + // Ok, we assume that details_adjusted_call_site is a better estimation. + details_selected = &details_adjusted_call_site; + trace.addr = (void*) (uintptr_t(trace.addr) - 1); + } + + if (details_selected == &details_call_site and details_call_site.found) { + // we have to re-resolve the symbol in order to reset some + // internal state in BFD... so we can call backtrace_inliners + // thereafter... + details_call_site = find_symbol_details(fobj, trace.addr, + symbol_info.dli_fbase); + } +#endif // BACKWARD_HAS_UNWIND + + if (details_selected->found) { + if (details_selected->filename) { + trace.source.filename = details_selected->filename; + } + trace.source.line = details_selected->line; + + if (details_selected->funcname) { + // this time we get the name of the function where the code is + // located, instead of the function were the address is + // located. In short, if the code was inlined, we get the + // function correspoding to the code. Else we already got in + // trace.function. + trace.source.function = demangle(details_selected->funcname); + + if (not symbol_info.dli_sname) { + // for the case dladdr failed to find the symbol name of + // the function, we might as well try to put something + // here. + trace.object_function = trace.source.function; + } + } + + // Maybe the source of the trace got inlined inside the function + // (trace.source.function). Let's see if we can get all the inlined + // calls along the way up to the initial call site. + trace.inliners = backtrace_inliners(fobj, *details_selected); + +#if 0 + if (trace.inliners.size() == 0) { + // Maybe the trace was not inlined... or maybe it was and we + // are lacking the debug information. Let's try to make the + // world better and see if we can get the line number of the + // function (trace.source.function) now. + // + // We will get the location of where the function start (to be + // exact: the first instruction that really start the + // function), not where the name of the function is defined. + // This can be quite far away from the name of the function + // btw. + // + // If the source of the function is the same as the source of + // the trace, we cannot say if the trace was really inlined or + // not. However, if the filename of the source is different + // between the function and the trace... we can declare it as + // an inliner. This is not 100% accurate, but better than + // nothing. + + if (symbol_info.dli_saddr) { + find_sym_result details = find_symbol_details(fobj, + symbol_info.dli_saddr, + symbol_info.dli_fbase); + + if (details.found) { + ResolvedTrace::SourceLoc diy_inliner; + diy_inliner.line = details.line; + if (details.filename) { + diy_inliner.filename = details.filename; + } + if (details.funcname) { + diy_inliner.function = demangle(details.funcname); + } else { + diy_inliner.function = trace.source.function; + } + if (diy_inliner != trace.source) { + trace.inliners.push_back(diy_inliner); + } + } + } + } +#endif + } + + return trace; + } + +private: + bool _bfd_loaded; + + typedef details::handle + > bfd_handle_t; + + typedef details::handle bfd_symtab_t; + + + struct bfd_fileobject { + bfd_handle_t handle; + bfd_vma base_addr; + bfd_symtab_t symtab; + bfd_symtab_t dynamic_symtab; + }; + + typedef details::hashtable::type + fobj_bfd_map_t; + fobj_bfd_map_t _fobj_bfd_map; + + bfd_fileobject& load_object_with_bfd(const std::string& filename_object) { + using namespace details; + + if (not _bfd_loaded) { + using namespace details; + bfd_init(); + _bfd_loaded = true; + } + + fobj_bfd_map_t::iterator it = + _fobj_bfd_map.find(filename_object); + if (it != _fobj_bfd_map.end()) { + return it->second; + } + + // this new object is empty for now. + bfd_fileobject& r = _fobj_bfd_map[filename_object]; + + // we do the work temporary in this one; + bfd_handle_t bfd_handle; + + int fd = open(filename_object.c_str(), O_RDONLY); + bfd_handle.reset( + bfd_fdopenr(filename_object.c_str(), "default", fd) + ); + if (not bfd_handle) { + close(fd); + return r; + } + + if (not bfd_check_format(bfd_handle.get(), bfd_object)) { + return r; // not an object? You lose. + } + + if ((bfd_get_file_flags(bfd_handle.get()) & HAS_SYMS) == 0) { + return r; // that's what happen when you forget to compile in debug. + } + + ssize_t symtab_storage_size = + bfd_get_symtab_upper_bound(bfd_handle.get()); + + ssize_t dyn_symtab_storage_size = + bfd_get_dynamic_symtab_upper_bound(bfd_handle.get()); + + if (symtab_storage_size <= 0 and dyn_symtab_storage_size <= 0) { + return r; // weird, is the file is corrupted? + } + + bfd_symtab_t symtab, dynamic_symtab; + ssize_t symcount = 0, dyn_symcount = 0; + + if (symtab_storage_size > 0) { + symtab.reset( + (bfd_symbol**) malloc(symtab_storage_size) + ); + symcount = bfd_canonicalize_symtab( + bfd_handle.get(), symtab.get() + ); + } + + if (dyn_symtab_storage_size > 0) { + dynamic_symtab.reset( + (bfd_symbol**) malloc(dyn_symtab_storage_size) + ); + dyn_symcount = bfd_canonicalize_dynamic_symtab( + bfd_handle.get(), dynamic_symtab.get() + ); + } + + + if (symcount <= 0 and dyn_symcount <= 0) { + return r; // damned, that's a stripped file that you got there! + } + + r.handle = move(bfd_handle); + r.symtab = move(symtab); + r.dynamic_symtab = move(dynamic_symtab); + return r; + } + + struct find_sym_result { + bool found; + const char* filename; + const char* funcname; + unsigned int line; + }; + + struct find_sym_context { + TraceResolverLinuxImpl* self; + bfd_fileobject* fobj; + void* addr; + void* base_addr; + find_sym_result result; + }; + + find_sym_result find_symbol_details(bfd_fileobject& fobj, void* addr, + void* base_addr) { + find_sym_context context; + context.self = this; + context.fobj = &fobj; + context.addr = addr; + context.base_addr = base_addr; + context.result.found = false; + bfd_map_over_sections(fobj.handle.get(), &find_in_section_trampoline, + (void*)&context); + return context.result; + } + + static void find_in_section_trampoline(bfd*, asection* section, + void* data) { + find_sym_context* context = static_cast(data); + context->self->find_in_section( + reinterpret_cast(context->addr), + reinterpret_cast(context->base_addr), + *context->fobj, + section, context->result + ); + } + + void find_in_section(bfd_vma addr, bfd_vma base_addr, + bfd_fileobject& fobj, asection* section, find_sym_result& result) + { + if (result.found) return; + + if ((bfd_get_section_flags(fobj.handle.get(), section) + & SEC_ALLOC) == 0) + return; // a debug section is never loaded automatically. + + bfd_vma sec_addr = bfd_get_section_vma(fobj.handle.get(), section); + bfd_size_type size = bfd_get_section_size(section); + + // are we in the boundaries of the section? + if (addr < sec_addr or addr >= sec_addr + size) { + addr -= base_addr; // oups, a relocated object, lets try again... + if (addr < sec_addr or addr >= sec_addr + size) { + return; + } + } + + if (not result.found and fobj.symtab) { + result.found = bfd_find_nearest_line(fobj.handle.get(), section, + fobj.symtab.get(), addr - sec_addr, &result.filename, + &result.funcname, &result.line); + } + + if (not result.found and fobj.dynamic_symtab) { + result.found = bfd_find_nearest_line(fobj.handle.get(), section, + fobj.dynamic_symtab.get(), addr - sec_addr, + &result.filename, &result.funcname, &result.line); + } + + } + + ResolvedTrace::source_locs_t backtrace_inliners(bfd_fileobject& fobj, + find_sym_result previous_result) { + // This function can be called ONLY after a SUCCESSFUL call to + // find_symbol_details. The state is global to the bfd_handle. + ResolvedTrace::source_locs_t results; + while (previous_result.found) { + find_sym_result result; + result.found = bfd_find_inliner_info(fobj.handle.get(), + &result.filename, &result.funcname, &result.line); + + if (result.found) /* and not ( + cstrings_eq(previous_result.filename, result.filename) + and cstrings_eq(previous_result.funcname, result.funcname) + and result.line == previous_result.line + )) */ { + ResolvedTrace::SourceLoc src_loc; + src_loc.line = result.line; + if (result.filename) { + src_loc.filename = result.filename; + } + if (result.funcname) { + src_loc.function = demangle(result.funcname); + } + results.push_back(src_loc); + } + previous_result = result; + } + return results; + } + + bool cstrings_eq(const char* a, const char* b) { + if (not a or not b) { + return false; + } + return strcmp(a, b) == 0; + } + +}; +#endif // BACKWARD_HAS_BFD == 1 + +#if BACKWARD_HAS_DW == 1 + +template <> +class TraceResolverLinuxImpl: + public TraceResolverLinuxImplBase { +public: + TraceResolverLinuxImpl(): _dwfl_handle_initialized(false) {} + + template + void load_stacktrace(ST&) {} + + ResolvedTrace resolve(ResolvedTrace trace) { + using namespace details; + + Dwarf_Addr trace_addr = (Dwarf_Addr) trace.addr; + + if (not _dwfl_handle_initialized) { + // initialize dwfl... + _dwfl_cb.reset(new Dwfl_Callbacks); + _dwfl_cb->find_elf = &dwfl_linux_proc_find_elf; + _dwfl_cb->find_debuginfo = &dwfl_standard_find_debuginfo; + _dwfl_cb->debuginfo_path = 0; + + _dwfl_handle.reset(dwfl_begin(_dwfl_cb.get())); + _dwfl_handle_initialized = true; + + if (not _dwfl_handle) { + return trace; + } + + // ...from the current process. + dwfl_report_begin(_dwfl_handle.get()); + int r = dwfl_linux_proc_report (_dwfl_handle.get(), getpid()); + dwfl_report_end(_dwfl_handle.get(), NULL, NULL); + if (r < 0) { + return trace; + } + } + + if (not _dwfl_handle) { + return trace; + } + + // find the module (binary object) that contains the trace's address. + // This is not using any debug information, but the addresses ranges of + // all the currently loaded binary object. + Dwfl_Module* mod = dwfl_addrmodule(_dwfl_handle.get(), trace_addr); + if (mod) { + // now that we found it, lets get the name of it, this will be the + // full path to the running binary or one of the loaded library. + const char* module_name = dwfl_module_info (mod, + 0, 0, 0, 0, 0, 0, 0); + if (module_name) { + trace.object_filename = module_name; + } + // We also look after the name of the symbol, equal or before this + // address. This is found by walking the symtab. We should get the + // symbol corresponding to the function (mangled) containing the + // address. If the code corresponding to the address was inlined, + // this is the name of the out-most inliner function. + const char* sym_name = dwfl_module_addrname(mod, trace_addr); + if (sym_name) { + trace.object_function = demangle(sym_name); + } + } + + // now let's get serious, and find out the source location (file and + // line number) of the address. + + // This function will look in .debug_aranges for the address and map it + // to the location of the compilation unit DIE in .debug_info and + // return it. + Dwarf_Addr mod_bias = 0; + Dwarf_Die* cudie = dwfl_module_addrdie(mod, trace_addr, &mod_bias); + +#if 1 + if (not cudie) { + // Sadly clang does not generate the section .debug_aranges, thus + // dwfl_module_addrdie will fail early. Clang doesn't either set + // the lowpc/highpc/range info for every compilation unit. + // + // So in order to save the world: + // for every compilation unit, we will iterate over every single + // DIEs. Normally functions should have a lowpc/highpc/range, which + // we will use to infer the compilation unit. + + // note that this is probably badly inefficient. + while ((cudie = dwfl_module_nextcu(mod, cudie, &mod_bias))) { + Dwarf_Die die_mem; + Dwarf_Die* fundie = find_fundie_by_pc(cudie, + trace_addr - mod_bias, &die_mem); + if (fundie) { + break; + } + } + } +#endif + +//#define BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE +#ifdef BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE + if (not cudie) { + // If it's still not enough, lets dive deeper in the shit, and try + // to save the world again: for every compilation unit, we will + // load the corresponding .debug_line section, and see if we can + // find our address in it. + + Dwarf_Addr cfi_bias; + Dwarf_CFI* cfi_cache = dwfl_module_eh_cfi(mod, &cfi_bias); + + Dwarf_Addr bias; + while ((cudie = dwfl_module_nextcu(mod, cudie, &bias))) { + if (dwarf_getsrc_die(cudie, trace_addr - bias)) { + + // ...but if we get a match, it might be a false positive + // because our (address - bias) might as well be valid in a + // different compilation unit. So we throw our last card on + // the table and lookup for the address into the .eh_frame + // section. + + handle frame; + dwarf_cfi_addrframe(cfi_cache, trace_addr - cfi_bias, &frame); + if (frame) { + break; + } + } + } + } +#endif + + if (not cudie) { + return trace; // this time we lost the game :/ + } + + // Now that we have a compilation unit DIE, this function will be able + // to load the corresponding section in .debug_line (if not already + // loaded) and hopefully find the source location mapped to our + // address. + Dwarf_Line* srcloc = dwarf_getsrc_die(cudie, trace_addr - mod_bias); + + if (srcloc) { + const char* srcfile = dwarf_linesrc(srcloc, 0, 0); + if (srcfile) { + trace.source.filename = srcfile; + } + int line = 0, col = 0; + dwarf_lineno(srcloc, &line); + dwarf_linecol(srcloc, &col); + trace.source.line = line; + trace.source.col = col; + } + + deep_first_search_by_pc(cudie, trace_addr - mod_bias, + inliners_search_cb(trace)); + if (trace.source.function.size() == 0) { + // fallback. + trace.source.function = trace.object_function; + } + + return trace; + } + +private: + typedef details::handle > + dwfl_handle_t; + details::handle > + _dwfl_cb; + dwfl_handle_t _dwfl_handle; + bool _dwfl_handle_initialized; + + // defined here because in C++98, template function cannot take locally + // defined types... grrr. + struct inliners_search_cb { + void operator()(Dwarf_Die* die) { + switch (dwarf_tag(die)) { + const char* name; + case DW_TAG_subprogram: + if ((name = dwarf_diename(die))) { + trace.source.function = name; + } + break; + + case DW_TAG_inlined_subroutine: + ResolvedTrace::SourceLoc sloc; + Dwarf_Attribute attr_mem; + + if ((name = dwarf_diename(die))) { + trace.source.function = name; + } + if ((name = die_call_file(die))) { + sloc.filename = name; + } + + Dwarf_Word line = 0, col = 0; + dwarf_formudata(dwarf_attr(die, DW_AT_call_line, + &attr_mem), &line); + dwarf_formudata(dwarf_attr(die, DW_AT_call_column, + &attr_mem), &col); + sloc.line = line; + sloc.col = col; + + trace.inliners.push_back(sloc); + break; + }; + } + ResolvedTrace& trace; + inliners_search_cb(ResolvedTrace& t): trace(t) {} + }; + + + static bool die_has_pc(Dwarf_Die* die, Dwarf_Addr pc) { + Dwarf_Addr low, high; + + // continuous range + if (dwarf_hasattr(die, DW_AT_low_pc) and + dwarf_hasattr(die, DW_AT_high_pc)) { + if (dwarf_lowpc(die, &low) != 0) { + return false; + } + if (dwarf_highpc(die, &high) != 0) { + Dwarf_Attribute attr_mem; + Dwarf_Attribute* attr = dwarf_attr(die, DW_AT_high_pc, &attr_mem); + Dwarf_Word value; + if (dwarf_formudata(attr, &value) != 0) { + return false; + } + high = low + value; + } + return pc >= low and pc < high; + } + + // non-continuous range. + Dwarf_Addr base; + ptrdiff_t offset = 0; + while ((offset = dwarf_ranges(die, offset, &base, &low, &high)) > 0) { + if (pc >= low and pc < high) { + return true; + } + } + return false; + } + + static Dwarf_Die* find_fundie_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, + Dwarf_Die* result) { + if (dwarf_child(parent_die, result) != 0) { + return 0; + } + + Dwarf_Die* die = result; + do { + switch (dwarf_tag(die)) { + case DW_TAG_subprogram: + case DW_TAG_inlined_subroutine: + if (die_has_pc(die, pc)) { + return result; + } + default: + bool declaration = false; + Dwarf_Attribute attr_mem; + dwarf_formflag(dwarf_attr(die, DW_AT_declaration, + &attr_mem), &declaration); + if (not declaration) { + // let's be curious and look deeper in the tree, + // function are not necessarily at the first level, but + // might be nested inside a namespace, structure etc. + Dwarf_Die die_mem; + Dwarf_Die* indie = find_fundie_by_pc(die, pc, &die_mem); + if (indie) { + *result = die_mem; + return result; + } + } + }; + } while (dwarf_siblingof(die, result) == 0); + return 0; + } + + template + static bool deep_first_search_by_pc(Dwarf_Die* parent_die, + Dwarf_Addr pc, CB cb) { + Dwarf_Die die_mem; + if (dwarf_child(parent_die, &die_mem) != 0) { + return false; + } + + bool branch_has_pc = false; + Dwarf_Die* die = &die_mem; + do { + bool declaration = false; + Dwarf_Attribute attr_mem; + dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration); + if (not declaration) { + // let's be curious and look deeper in the tree, function are + // not necessarily at the first level, but might be nested + // inside a namespace, structure, a function, an inlined + // function etc. + branch_has_pc = deep_first_search_by_pc(die, pc, cb); + } + if (not branch_has_pc) { + branch_has_pc = die_has_pc(die, pc); + } + if (branch_has_pc) { + cb(die); + } + } while (dwarf_siblingof(die, &die_mem) == 0); + return branch_has_pc; + } + + static const char* die_call_file(Dwarf_Die *die) { + Dwarf_Attribute attr_mem; + Dwarf_Sword file_idx = 0; + + dwarf_formsdata(dwarf_attr(die, DW_AT_call_file, &attr_mem), + &file_idx); + + if (file_idx == 0) { + return 0; + } + + Dwarf_Die die_mem; + Dwarf_Die* cudie = dwarf_diecu(die, &die_mem, 0, 0); + if (not cudie) { + return 0; + } + + Dwarf_Files* files = 0; + size_t nfiles; + dwarf_getsrcfiles(cudie, &files, &nfiles); + if (not files) { + return 0; + } + + return dwarf_filesrc(files, file_idx, 0, 0); + } + +}; +#endif // BACKWARD_HAS_DW == 1 + +template<> +class TraceResolverImpl: + public TraceResolverLinuxImpl {}; + +#endif // BACKWARD_SYSTEM_LINUX + +class TraceResolver: + public TraceResolverImpl {}; + +/*************** CODE SNIPPET ***************/ + +class SourceFile { +public: + typedef std::vector > lines_t; + + SourceFile() {} + SourceFile(const std::string& path): _file(new std::ifstream(path.c_str())) {} + bool is_open() const { return _file->is_open(); } + + lines_t& get_lines(unsigned line_start, unsigned line_count, lines_t& lines) { + using namespace std; + // This function make uses of the dumbest algo ever: + // 1) seek(0) + // 2) read lines one by one and discard until line_start + // 3) read line one by one until line_start + line_count + // + // If you are getting snippets many time from the same file, it is + // somewhat a waste of CPU, feel free to benchmark and propose a + // better solution ;) + + _file->clear(); + _file->seekg(0); + string line; + unsigned line_idx; + + for (line_idx = 1; line_idx < line_start; ++line_idx) { + getline(*_file, line); + if (not *_file) { + return lines; + } + } + + // think of it like a lambda in C++98 ;) + // but look, I will reuse it two times! + // What a good boy am I. + struct isspace { + bool operator()(char c) { + return std::isspace(c); + } + }; + + bool started = false; + for (; line_idx < line_start + line_count; ++line_idx) { + getline(*_file, line); + if (not *_file) { + return lines; + } + if (not started) { + if (std::find_if(line.begin(), line.end(), + not_isspace()) == line.end()) + continue; + started = true; + } + lines.push_back(make_pair(line_idx, line)); + } + + lines.erase( + std::find_if(lines.rbegin(), lines.rend(), + not_isempty()).base(), lines.end() + ); + return lines; + } + + lines_t get_lines(unsigned line_start, unsigned line_count) { + lines_t lines; + return get_lines(line_start, line_count, lines); + } + + // there is no find_if_not in C++98, lets do something crappy to + // workaround. + struct not_isspace { + bool operator()(char c) { + return not std::isspace(c); + } + }; + // and define this one here because C++98 is not happy with local defined + // struct passed to template functions, fuuuu. + struct not_isempty { + bool operator()(const lines_t::value_type& p) { + return not (std::find_if(p.second.begin(), p.second.end(), + not_isspace()) == p.second.end()); + } + }; + + void swap(SourceFile& b) { + _file.swap(b._file); + } + +#if defined(BACKWARD_CXX11) + SourceFile(SourceFile&& from): _file(0) { + swap(from); + } + SourceFile& operator=(SourceFile&& from) { + swap(from); return *this; + } +#else + explicit SourceFile(const SourceFile& from) { + // some sort of poor man's move semantic. + swap(const_cast(from)); + } + SourceFile& operator=(const SourceFile& from) { + // some sort of poor man's move semantic. + swap(const_cast(from)); return *this; + } +#endif + +private: + details::handle + > _file; + +#if defined(BACKWARD_CXX11) + SourceFile(const SourceFile&) = delete; + SourceFile& operator=(const SourceFile&) = delete; +#endif +}; + +class SnippetFactory { +public: + typedef SourceFile::lines_t lines_t; + + lines_t get_snippet(const std::string& filename, + unsigned line_start, unsigned context_size) { + + SourceFile& src_file = get_src_file(filename); + unsigned start = line_start - context_size / 2; + return src_file.get_lines(start, context_size); + } + + lines_t get_combined_snippet( + const std::string& filename_a, unsigned line_a, + const std::string& filename_b, unsigned line_b, + unsigned context_size) { + SourceFile& src_file_a = get_src_file(filename_a); + SourceFile& src_file_b = get_src_file(filename_b); + + lines_t lines = src_file_a.get_lines(line_a - context_size / 4, + context_size / 2); + src_file_b.get_lines(line_b - context_size / 4, context_size / 2, + lines); + return lines; + } + + lines_t get_coalesced_snippet(const std::string& filename, + unsigned line_a, unsigned line_b, unsigned context_size) { + SourceFile& src_file = get_src_file(filename); + + using std::min; using std::max; + unsigned a = min(line_a, line_b); + unsigned b = max(line_a, line_b); + + if ((b - a) < (context_size / 3)) { + return src_file.get_lines((a + b - context_size + 1) / 2, + context_size); + } + + lines_t lines = src_file.get_lines(a - context_size / 4, + context_size / 2); + src_file.get_lines(b - context_size / 4, context_size / 2, lines); + return lines; + } + + +private: + typedef details::hashtable::type src_files_t; + src_files_t _src_files; + + SourceFile& get_src_file(const std::string& filename) { + src_files_t::iterator it = _src_files.find(filename); + if (it != _src_files.end()) { + return it->second; + } + SourceFile& new_src_file = _src_files[filename]; + new_src_file = SourceFile(filename); + return new_src_file; + } +}; + +/*************** PRINTER ***************/ + +#ifdef BACKWARD_SYSTEM_LINUX + +namespace Color { + enum type { + yellow = 33, + purple = 35, + reset = 39 + }; +} // namespace Color + +class Colorize { +public: + Colorize(std::FILE* os): + _os(os), _reset(false), _istty(false) {} + + void init() { + _istty = isatty(fileno(_os)); + } + + void set_color(Color::type ccode) { + if (not _istty) return; + + // I assume that the terminal can handle basic colors. Seriously I + // don't want to deal with all the termcap shit. + fprintf(_os, "\033[%im", static_cast(ccode)); + _reset = (ccode != Color::reset); + } + + ~Colorize() { + if (_reset) { + set_color(Color::reset); + } + } + +private: + std::FILE* _os; + bool _reset; + bool _istty; +}; + +#else // ndef BACKWARD_SYSTEM_LINUX + + +namespace Color { + enum type { + yellow = 0, + purple = 0, + reset = 0 + }; +} // namespace Color + +class Colorize { +public: + Colorize(std::FILE*) {} + void init() {} + void set_color(Color::type) {} +}; + +#endif // BACKWARD_SYSTEM_LINUX + +class Printer { +public: + bool snippet; + bool color; + bool address; + bool object; + + Printer(): + snippet(true), + color(true), + address(false), + object(false) + {} + + template + FILE* print(StackTrace& st, FILE* os = stderr) { + using namespace std; + + Colorize colorize(os); + if (color) { + colorize.init(); + } + + fprintf(os, "Stack trace (most recent call last)"); + if (st.thread_id()) { + fprintf(os, " in thread %u:\n", st.thread_id()); + } else { + fprintf(os, ":\n"); + } + + _resolver.load_stacktrace(st); + for (unsigned trace_idx = st.size(); trace_idx > 0; --trace_idx) { + fprintf(os, "#%-2u", trace_idx); + bool already_indented = true; + const ResolvedTrace trace = _resolver.resolve(st[trace_idx-1]); + + if (not trace.source.filename.size() or object) { + fprintf(os, " Object \"%s\", at %p, in %s\n", + trace.object_filename.c_str(), trace.addr, + trace.object_function.c_str()); + already_indented = false; + } + + if (trace.source.filename.size()) { + for (size_t inliner_idx = trace.inliners.size(); + inliner_idx > 0; --inliner_idx) { + if (not already_indented) { + fprintf(os, " "); + } + const ResolvedTrace::SourceLoc& inliner_loc + = trace.inliners[inliner_idx-1]; + print_source_loc(os, " | ", inliner_loc); + if (snippet) { + print_snippet(os, " | ", inliner_loc, + colorize, Color::purple, 5); + } + already_indented = false; + } + + if (not already_indented) { + fprintf(os, " "); + } + print_source_loc(os, " ", trace.source, trace.addr); + if (snippet) { + print_snippet(os, " ", trace.source, + colorize, Color::yellow, 7); + } + + if (trace.locals.size()) { + print_locals(os, " ", trace.locals); + } + } + } + return os; + } +private: + TraceResolver _resolver; + SnippetFactory _snippets; + + void print_snippet(FILE* os, const char* indent, + const ResolvedTrace::SourceLoc& source_loc, + Colorize& colorize, Color::type color_code, + int context_size) + { + using namespace std; + typedef SnippetFactory::lines_t lines_t; + + lines_t lines = _snippets.get_snippet(source_loc.filename, + source_loc.line, context_size); + + for (lines_t::const_iterator it = lines.begin(); + it != lines.end(); ++it) { + if (it-> first == source_loc.line) { + colorize.set_color(color_code); + fprintf(os, "%s>", indent); + } else { + fprintf(os, "%s ", indent); + } + fprintf(os, "%4u: %s\n", it->first, it->second.c_str()); + if (it-> first == source_loc.line) { + colorize.set_color(Color::reset); + } + } + } + + void print_source_loc(FILE* os, const char* indent, + const ResolvedTrace::SourceLoc& source_loc, + void* addr=0) { + fprintf(os, "%sSource \"%s\", line %i, in %s", + indent, source_loc.filename.c_str(), (int)source_loc.line, + source_loc.function.c_str()); + + if (address and addr != 0) { + fprintf(os, " [%p]\n", addr); + } else { + fprintf(os, "\n"); + } + } + + void print_var(FILE* os, const char* base_indent, int indent, + const Variable& var) { + fprintf(os, "%s%s: ", base_indent, var.name.c_str()); + switch (var.kind) { + case Variable::VALUE: + fprintf(os, "%s\n", var.value().c_str()); + break; + case Variable::LIST: + fprintf(os, "["); + for (size_t i = 0; i < var.list().size(); ++i) { + if (i > 0) { + fprintf(os, ", %s", var.list()[i].c_str()); + } + fprintf(os, "%s", var.list()[i].c_str()); + } + fprintf(os, "]\n"); + break; + case Variable::MAP: + fprintf(os, "{\n"); + for (size_t i = 0; i < var.map().size(); ++i) { + if (i > 0) { + fprintf(os, ",\n%s", base_indent); + } + print_var(os, base_indent, indent + 2, var.map()[i]); + } + fprintf(os, "]\n"); + break; + }; + } + + void print_locals(FILE* os, const char* indent, + const std::vector& locals) { + fprintf(os, "%sLocal variables:\n", indent); + for (size_t i = 0; i < locals.size(); ++i) { + if (i > 0) { + fprintf(os, ",\n%s", indent); + } + print_var(os, indent, 0, locals[i]); + } + } +}; + +/*************** SIGNALS HANDLING ***************/ + +#ifdef BACKWARD_SYSTEM_LINUX + + +class SignalHandling { +public: + static std::vector make_default_signals() { + const int signals[] = { + // default action: Core + SIGILL, + SIGABRT, + SIGFPE, + SIGSEGV, + SIGBUS, + // I am not sure the following signals should be enabled by + // default: + // default action: Term + SIGHUP, + SIGINT, + SIGPIPE, + SIGALRM, + SIGTERM, + SIGUSR1, + SIGUSR2, + SIGPOLL, + SIGPROF, + SIGVTALRM, + SIGIO, + SIGPWR, + // default action: Core + SIGQUIT, + SIGSYS, + SIGTRAP, + SIGXCPU, + SIGXFSZ + }; + return std::vector(signals, signals + sizeof signals); + } + + SignalHandling(const std::vector& signals = make_default_signals()) : _loaded(false) { + bool success = true; + + const size_t stack_size = 1024 * 1024 * 8; + _stack_content.reset((char*)malloc(stack_size)); + if (_stack_content) { + stack_t ss; + ss.ss_sp = _stack_content.get(); + ss.ss_size = stack_size; + ss.ss_flags = 0; + if (sigaltstack(&ss, 0) < 0) { + success = false; + } + } else { + success = false; + } + + for (size_t i = 0; i < signals.size(); ++i) { + struct sigaction action; + action.sa_flags = SA_SIGINFO | SA_ONSTACK; + sigemptyset(&action.sa_mask); + action.sa_sigaction = &sig_handler; + + int r = sigaction(signals[i], &action, 0); + if (r < 0) success = false; + } + _loaded = success; + } + + bool loaded() const { return _loaded; } + +private: + details::handle _stack_content; + bool _loaded; + + static void sig_handler(int, siginfo_t* info, void* _ctx) { + ucontext_t *uctx = (ucontext_t*) _ctx; + + StackTrace st; + void* error_addr = 0; +#ifdef REG_RIP // x86_64 + error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_RIP]); +#elif defined(REG_EIP) // x86_32 + error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_EIP]); +#else +# warning ":/ sorry, ain't know no nothing none not of your architecture!" +#endif + if (error_addr) { + st.load_from(error_addr, 32); + } else { + st.load_here(32); + } + + Printer printer; + printer.address = true; + printer.print(st, stderr); + + psiginfo(info, 0); + // terminate the process immediately. + _exit(EXIT_FAILURE); + } +}; + +#endif // BACKWARD_SYSTEM_LINUX + +#ifdef BACKWARD_SYSTEM_UNKNOWN + +class SignalHandling { +public: + SignalHandling(const std::vector& = std::vector()) {} + bool init() { return false; } +}; + +#endif // BACKWARD_SYSTEM_UNKNOWN + +#if 0 +void crit_err_hdlr(int sig_num, siginfo_t * info, void * ucontext) +{ + void * array[50]; + void * caller_address; + char ** messages; + int size, i; + sig_ucontext_t * uc; + + uc = (sig_ucontext_t *)ucontext; + + /* Get the address at the time the signal was raised from the EIP (x86) */ + caller_address = (void *) uc->uc_mcontext.eip; + + fprintf(stderr, "signal %d (%s), address is %p from %p\n", + sig_num, strsignal(sig_num), info->si_addr, + (void *)caller_address); + + size = backtrace(array, 50); + + /* overwrite sigaction with caller's address */ + array[1] = caller_address; + + messages = backtrace_symbols(array, size); + + +void sig_handler(int sig, siginfo_t* info, void* _ctx) { +ucontext_t *context = (ucontext_t*) _ctx; + +psiginfo(info, "Shit hit the fan"); +exit(EXIT_FAILURE); +} + +using namespace std; + +void badass() { +cout << "baddass!" << endl; +((char*)&badass)[0] = 42; +} + +int main() { +struct sigaction action; +action.sa_flags = SA_SIGINFO; +sigemptyset(&action.sa_mask); +action.sa_sigaction = &sig_handler; +int r = sigaction(SIGSEGV, &action, 0); +if (r < 0) { err(errno, 0); } +r = sigaction(SIGILL, &action, 0); +if (r < 0) { err(errno, 0); } + +badass(); +return 0; +} + + +#endif + +// i want to get a stacktrace on: +// - abort +// - signals (segfault.. abort...) +// - exception +// - dont messup with gdb! +// - thread ID +// - helper for capturing stack trace inside exception +// propose a little magic wrapper to throw an exception adding a stacktrace, +// and propose a specific tool to get a stacktrace from an exception (if its +// available). +// - optional override __cxa_throw, then the specific magic tool could get +// the stacktrace. Might be possible to use a thread-local variable to do +// some shit. RTLD_DEEPBIND might do the tricks to override it on the fly. + +// maybe I can even get the last variables and theirs values? +// that might be possible. + +// print with code snippet +// print traceback demangled +// detect color stuff +// register all signals +// +// Seperate stacktrace (load and co function) +// than object extracting informations about a stack trace. + +// also public a simple function to print a stacktrace. + +// backtrace::StackTrace st; +// st.snapshot(); +// print(st); +// cout << st; + +} // namespace backward + +#endif /* H_GUARD */ diff --git a/src/dionysus/dionysus/chain.h b/src/dionysus/dionysus/chain.h new file mode 100755 index 0000000..00c9836 --- /dev/null +++ b/src/dionysus/dionysus/chain.h @@ -0,0 +1,153 @@ +#ifndef DIONYSUS_CHAIN_H +#define DIONYSUS_CHAIN_H + +#include +#include +#include + +#include "fields/z2.h" + +namespace dionysus +{ + +template +struct FieldElement +{ + typedef typename Field::Element Element; + FieldElement(Element e_): + e(e_) {} + Element element() const { return e; } + void set_element(Element e_) { e = e_; } + Element e; +}; + +template<> +struct FieldElement +{ + typedef Z2Field::Element Element; + FieldElement(Element) {} + Element element() const { return Z2Field::id(); } + void set_element(Element) {} +}; + +template +struct ChainEntry: public FieldElement, public Extra... +{ + typedef Field_ Field; + typedef Index_ Index; + + typedef FieldElement Parent; + typedef typename Parent::Element Element; + + ChainEntry(): Parent(Element()), i(Index()) {} // need for serialization + + ChainEntry(ChainEntry&& other) = default; + ChainEntry(const ChainEntry& other) = default; + ChainEntry& operator=(ChainEntry&& other) = default; + + ChainEntry(Element e_, const Index& i_): + Parent(e_), i(i_) {} + + ChainEntry(Element e_, Index&& i_): + Parent(e_), i(std::move(i_)) {} + + const Index& index() const { return i; } + Index& index() { return i; } + + // debug + bool operator==(const ChainEntry& other) const { return i == other.i; } + + Index i; +}; + +template +struct Chain +{ + struct Visitor + { + template + void first(Iter it) const {} + + template + void second(Iter it) const {} + + template + void equal_keep(Iter it) const {} + + template + void equal_drop(Iter it) const {} + }; + + // x += a*y + template + static void addto(C1& x, typename Field::Element a, const C2& y, const Field& field, const Cmp& cmp, const Visitor_& = Visitor_()); +}; + +template +struct Chain> +{ + struct Visitor + { + template + void first(Iter it) const {} + + template + void second(Iter it) const {} + + template + void equal_keep(Iter it) const {} + + template + void equal_drop(Iter it) const {} + }; + + // x += a*y + template + static void addto(std::list& x, typename Field::Element a, const C2& y, + const Field& field, const Cmp& cmp, const Visitor_& visitor = Visitor_()); +}; + + +template +struct Chain> +{ + struct Visitor + { + template + void first(Iter it) const {} + + template + void second(Iter it) const {} + + template + void equal_keep(Iter it) const {} + + template + void equal_drop(Iter it) const {} + }; + + // x += a*y + template + static void addto(std::set& x, typename Field::Element a, const C2& y, + const Field& field, const Cmp& cmp, const Visitor_& = Visitor_()); + + template + static void addto(std::set& x, typename Field::Element a, T&& y, + const Field& field, const Cmp& cmp, const Visitor_& = Visitor_()); +}; + +} + +//namespace std +//{ +// template +// void swap(::dionysus::ChainEntry& x, ::dionysus::ChainEntry& y) +// { +// std::swap(x.e, y.e); +// std::swap(x.i, y.i); +// } +//} + +#include "chain.hpp" + +#endif diff --git a/src/dionysus/dionysus/chain.hpp b/src/dionysus/dionysus/chain.hpp new file mode 100755 index 0000000..4da9f44 --- /dev/null +++ b/src/dionysus/dionysus/chain.hpp @@ -0,0 +1,188 @@ +template +template +void +dionysus::Chain>:: +addto(std::list& x, typename Field::Element a, const C2& y, const Field& field, const Cmp& cmp, const Visitor_& visitor) +{ + typedef typename Field::Element Element; + + auto cur_x = std::begin(x), + end_x = std::end(x); + auto cur_y = std::begin(y), + end_y = std::end(y); + + while (cur_x != end_x && cur_y != end_y) + { + if (cmp(cur_x->index(), cur_y->index())) + { + visitor.first(cur_x++); + } else if (cmp(cur_y->index(), cur_x->index())) + { + // multiply and add + Element ay = field.mul(a, cur_y->element()); + auto nw_x = x.insert(cur_x, *cur_y); + nw_x->set_element(ay); + ++cur_y; + visitor.second(nw_x); + } else + { + Element ay = field.mul(a, cur_y->element()); + Element r = field.add(cur_x->element(), ay); + if (field.is_zero(r)) + { + visitor.equal_drop(cur_x); + x.erase(cur_x++); + } + else + { + cur_x->set_element(r); + visitor.equal_keep(cur_x); + ++cur_x; + } + ++cur_y; + } + } + + for (auto it = cur_y; it != end_y; ++it) + { + Element ay = field.mul(a, it->element()); + x.push_back(*it); + x.back().set_element(ay); + visitor.second(--x.end()); + } +} + +template +template +void +dionysus::Chain>:: +addto(std::set& x, typename Field::Element a, const C2& y, const Field& field, const Cmp&, const Visitor_& visitor) +{ + typedef typename Field::Element Element; + + auto cur_y = std::begin(y), + end_y = std::end(y); + + while (cur_y != end_y) + { + auto cur_x = x.find(*cur_y); + if (cur_x == x.end()) + { + auto nw = x.insert(*cur_y).first; + Element ay = field.mul(a, nw->element()); + const_cast(*nw).set_element(ay); + visitor.second(nw); + } else + { + Element ay = field.mul(a, cur_y->element()); + Element r = field.add(cur_x->element(), ay); + if (field.is_zero(r)) + { + visitor.equal_drop(cur_x); + x.erase(cur_x); + } + else + { + const_cast(*cur_x).set_element(r); + visitor.equal_keep(cur_x); + } + } + ++cur_y; + } +} + +template +template +void +dionysus::Chain>:: +addto(std::set& x, typename Field::Element a, T&& y, const Field& field, const Cmp&, const Visitor_& visitor) +{ + typedef typename Field::Element Element; + + auto cur_x = x.find(y); + if (cur_x == x.end()) + { + auto nw = x.insert(std::move(y)).first; + Element ay = field.mul(a, nw->element()); + const_cast(*nw).set_element(ay); + visitor.second(nw); + } else + { + Element ay = field.mul(a, y.element()); + Element r = field.add(cur_x->element(), ay); + if (field.is_zero(r)) + { + visitor.equal_drop(cur_x); + x.erase(cur_x); + } + else + { + const_cast(*cur_x).set_element(r); + visitor.equal_keep(cur_x); + } + } +} + +template +template +void +dionysus::Chain:: +addto(C1& x, typename Field::Element a, const C2& y, const Field& field, const Cmp& cmp, const Visitor_& visitor) +{ + typedef typename Field::Element Element; + + C1 res; + + auto cur_x = std::begin(x), + end_x = std::end(x); + auto cur_y = std::begin(y), + end_y = std::end(y); + + while (cur_x != end_x && cur_y != end_y) + { + if (cmp(*cur_x, *cur_y)) + { + res.emplace_back(std::move(*cur_x)); + visitor.first(--res.end()); + ++cur_x; + } else if (cmp(*cur_y, *cur_x)) + { + // multiply and add + Element ay = field.mul(a, cur_y->element()); + res.emplace_back(ay, cur_y->index()); + visitor.second(--res.end()); + ++cur_y; + } else + { + Element ay = field.mul(a, cur_y->element()); + Element r = field.add(cur_x->element(), ay); + if (field.is_zero(r)) + visitor.equal_drop(cur_x); + else + { + res.emplace_back(std::move(*cur_x)); + res.back().set_element(r); + visitor.equal_keep(--res.end()); + } + ++cur_x; + ++cur_y; + } + } + + while (cur_y != end_y) + { + Element ay = field.mul(a, cur_y->element()); + res.emplace_back(ay, cur_y->index()); + visitor.second(--res.end()); + ++cur_y; + } + + while (cur_x != end_x) + { + res.emplace_back(std::move(*cur_x)); + visitor.first(--res.end()); + ++cur_x; + } + + x.swap(res); +} diff --git a/src/dionysus/dionysus/clearing-reduction.h b/src/dionysus/dionysus/clearing-reduction.h new file mode 100755 index 0000000..8651e9a --- /dev/null +++ b/src/dionysus/dionysus/clearing-reduction.h @@ -0,0 +1,45 @@ +#ifndef DIONYSUS_CLEARING_REDUCTION_H +#define DIONYSUS_CLEARING_REDUCTION_H + +namespace dionysus +{ + +// Mid-level interface +template +class ClearingReduction +{ + public: + using Persistence = Persistence_; + using Field = typename Persistence::Field; + using Index = typename Persistence::Index; + + public: + ClearingReduction(Persistence& persistence): + persistence_(persistence) {} + + template + void operator()(const Filtration& f, const Relative& relative, const ReportPair& report_pair, const Progress& progress); + + template + void operator()(const Filtration& f, const ReportPair& report_pair); + + template + void operator()(const Filtration& f) { return (*this)(f, &no_report_pair); } + + static void no_report_pair(int, Index, Index) {} + static void no_progress() {} + + const Persistence& + persistence() const { return persistence_; } + Persistence& persistence() { return persistence_; } + + private: + Persistence& persistence_; +}; + +} + +#include "clearing-reduction.hpp" + +#endif + diff --git a/src/dionysus/dionysus/clearing-reduction.hpp b/src/dionysus/dionysus/clearing-reduction.hpp new file mode 100755 index 0000000..ceac118 --- /dev/null +++ b/src/dionysus/dionysus/clearing-reduction.hpp @@ -0,0 +1,60 @@ +#include +#include + +#include +namespace ba = boost::adaptors; + +template +template +void +dionysus::ClearingReduction

:: +operator()(const Filtration& filtration, const ReportPair& report_pair) +{ + using Cell = typename Filtration::Cell; + (*this)(filtration, [](const Cell&) { return false; }, report_pair, &no_progress); +} + +template +template +void +dionysus::ClearingReduction

:: +operator()(const Filtration& filtration, const Relative& relative, const ReportPair& report_pair, const Progress& progress) +{ + persistence_.resize(filtration.size()); + + // sort indices by decreasing dimension + std::vector indices(filtration.size()); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&filtration](size_t x, size_t y) + { return filtration[x].dimension() > filtration[y].dimension(); }); + + typedef typename Filtration::Cell Cell; + typedef ChainEntry CellChainEntry; + typedef ChainEntry ChainEntry; + + for(size_t i : indices) + { + progress(); + const auto& c = filtration[i]; + + if (relative(c)) + { + persistence_.set_skip(i); + continue; + } + + if (persistence_.pair(i) != persistence_.unpaired()) + continue; + + persistence_.set(i, c.boundary(persistence_.field()) | + ba::filtered([relative](const CellChainEntry& e) { return !relative(e.index()); }) | + ba::transformed([this,&filtration](const CellChainEntry& e) + { return ChainEntry(e.element(), filtration.index(e.index())); })); + + Index pair = persistence_.reduce(i); + if (pair != persistence_.unpaired()) + report_pair(c.dimension(), pair, i); + } +} + diff --git a/src/dionysus/dionysus/cohomology-persistence.h b/src/dionysus/dionysus/cohomology-persistence.h new file mode 100755 index 0000000..8d2019e --- /dev/null +++ b/src/dionysus/dionysus/cohomology-persistence.h @@ -0,0 +1,116 @@ +#ifndef DIONYSUS_COHOMOLOGY_PERSISTENCE_H +#define DIONYSUS_COHOMOLOGY_PERSISTENCE_H + +#include +#include + +#include +namespace bi = boost::intrusive; + +#include "reduction.h" +#include "chain.h" + +namespace dionysus +{ + +template> +class CohomologyPersistence +{ + public: + typedef Field_ Field; + typedef Index_ Index; + typedef Comparison_ Comparison; + + typedef typename Field::Element FieldElement; + + typedef bi::list_base_hook> auto_unlink_hook; + struct Entry; + struct ColumnHead; + + typedef std::vector Column; + typedef bi::list> Row; + typedef std::list Columns; + typedef typename Columns::iterator ColumnsIterator; + typedef Column Chain; + + using IndexColumn = std::tuple; + + CohomologyPersistence(const Field& field, + const Comparison& cmp = Comparison()): + field_(field), cmp_(cmp) {} + + CohomologyPersistence(Field&& field, + const Comparison& cmp = Comparison()): + field_(std::move(field)), + cmp_(cmp) {} + + CohomologyPersistence(CohomologyPersistence&& other): + field_(std::move(other.field_)), + cmp_(std::move(other.cmp_)), + columns_(std::move(other.columns_)), + rows_(std::move(other.rows_)) {} + + template + Index add(const ChainRange& chain); + + template + IndexColumn add(const ChainRange& chain, bool keep_cocycle); + + // TODO: no skip support for now + bool skip(Index) const { return false; } + void add_skip() {} + void set_skip(Index, bool flag = true) {} + + const Field& field() const { return field_; } + const Columns& columns() const { return columns_; } + void reserve(size_t s) { rows_.reserve(s); } + + struct AddtoVisitor; + + static const Index unpaired() { return Reduction::unpaired; } + + private: + Field field_; + Comparison cmp_; + Columns columns_; + std::vector rows_; +}; + + +template +struct CohomologyPersistence::ColumnHead +{ + ColumnHead(Index i): index_(i) {} + + Index index() const { return index_; } + + Index index_; + Column chain; +}; + +template +struct CohomologyPersistence::Entry: + public ChainEntry +{ + typedef ChainEntry Parent; + + Entry(FieldElement e, const Index& i): // slightly dangerous + Parent(e,i) {} + + Entry(FieldElement e, const Index& i, ColumnsIterator it): + Parent(e,i), column(it) {} + + Entry(const Entry& other) = default; + Entry(Entry&& other) = default; + + void unlink() { auto_unlink_hook::unlink(); } + bool is_linked() const { return auto_unlink_hook::is_linked(); } + + ColumnsIterator column; // TODO: I really don't like this overhead +}; + +} + +#include "cohomology-persistence.hpp" + +#endif diff --git a/src/dionysus/dionysus/cohomology-persistence.hpp b/src/dionysus/dionysus/cohomology-persistence.hpp new file mode 100755 index 0000000..b2334f9 --- /dev/null +++ b/src/dionysus/dionysus/cohomology-persistence.hpp @@ -0,0 +1,61 @@ +template +template +typename dionysus::CohomologyPersistence::Index +dionysus::CohomologyPersistence:: +add(const ChainRange& chain) +{ + return std::get<0>(add(chain, false)); // return just the index +} + + +template +template +typename dionysus::CohomologyPersistence::IndexColumn +dionysus::CohomologyPersistence:: +add(const ChainRange& chain, bool keep_cocycle) +{ + auto entry_cmp = [this](const Entry& e1, const Entry& e2) { return this->cmp_(e1.index(), e2.index()); }; + std::set row_sum(entry_cmp); + for (auto it = std::begin(chain); it != std::end(chain); ++it) + for (auto& re : rows_[it->index()]) + dionysus::Chain::addto(row_sum, it->element(), Entry(re.element(), re.column->index(), re.column), field_, cmp_); + + if (row_sum.empty()) // Birth + { + columns_.emplace_back(rows_.size()); + auto before_end = columns_.end(); + --before_end; + columns_.back().chain.push_back(Entry(field_.id(), rows_.size(), before_end)); + rows_.emplace_back(); + rows_.back().push_back(columns_.back().chain.front()); + return std::make_tuple(unpaired(), Column()); + } else // Death + { + // Select front element in terms of comparison (rows are unsorted) + auto it = std::max_element(std::begin(row_sum), std::end(row_sum), entry_cmp); + + Entry first = std::move(*it); + row_sum.erase(it); + + for (auto& ce : row_sum) + { + FieldElement ay = field_.neg(field_.div(ce.element(), first.element())); + dionysus::Chain::addto(ce.column->chain, ay, first.column->chain, field_, + [this](const Entry& e1, const Entry& e2) + { return this->cmp_(e1.index(), e2.index()); }); + + for (auto& x : ce.column->chain) + { + x.column = ce.column; + rows_[x.index()].push_back(x); + } + } + Index pair = first.column->index(); + Column cocycle; + if (keep_cocycle) + cocycle = std::move(first.column->chain); + columns_.erase(first.column); + rows_.emplace_back(); // useless row; only present to make indices match + return std::make_tuple(pair, cocycle); + } +} diff --git a/src/dionysus/dionysus/diagram.h b/src/dionysus/dionysus/diagram.h new file mode 100755 index 0000000..160fc9f --- /dev/null +++ b/src/dionysus/dionysus/diagram.h @@ -0,0 +1,105 @@ +#ifndef DIONYSUS_DIAGRAM_H +#define DIONYSUS_DIAGRAM_H + +#include +#include + +namespace dionysus +{ + +template +class Diagram +{ + public: + using Value = Value_; + using Data = Data_; + struct Point: public std::pair + { + using Parent = std::pair; + + Point(Value b, Value d, Data dd): + Parent(b,d), data(dd) {} + + Value birth() const { return Parent::first; } + Value death() const { return Parent::second; } + + // FIXME: temporary hack + Value operator[](size_t i) const { if (i == 0) return birth(); return death(); } + + Data data; + }; + + using Points = std::vector; + using iterator = typename Points::iterator; + using const_iterator = typename Points::const_iterator; + using value_type = Point; + + public: + const_iterator begin() const { return points.begin(); } + const_iterator end() const { return points.end(); } + iterator begin() { return points.begin(); } + iterator end() { return points.end(); } + + const Point& operator[](size_t i) const { return points[i]; } + + size_t size() const { return points.size(); } + void push_back(const Point& p) { points.push_back(p); } + template + void emplace_back(Args&&... args) { points.emplace_back(std::forward(args)...); } + + private: + std::vector points; +}; + +namespace detail +{ + template + struct Diagrams + { + using Value = decltype(std::declval()(std::declval())); + using Data = decltype(std::declval()(std::declval())); + using type = std::vector>; + }; +} + +template +typename detail::Diagrams::type +init_diagrams(const ReducedMatrix& m, const Filtration& f, const GetValue& get_value, const GetData& get_data) +{ + using Result = typename detail::Diagrams::type; + + Result diagrams; + for (typename ReducedMatrix::Index i = 0; i < m.size(); ++i) + { + if (m.skip(i)) + continue; + + auto& s = f[i]; + auto d = s.dimension(); + + while (d + 1 > diagrams.size()) + diagrams.emplace_back(); + + auto pair = m.pair(i); + if (pair == m.unpaired()) + { + auto birth = get_value(s); + using Value = decltype(birth); + Value death = std::numeric_limits::infinity(); + diagrams[d].emplace_back(birth, death, get_data(i)); + } else if (pair > i) // positive + { + auto birth = get_value(s); + auto death = get_value(f[pair]); + + if (birth != death) // skip diagonal + diagrams[d].emplace_back(birth, death, get_data(i)); + } // else negative: do nothing + } + + return diagrams; +} + +} + +#endif diff --git a/src/dionysus/dionysus/distances.h b/src/dionysus/dionysus/distances.h new file mode 100755 index 0000000..29cac60 --- /dev/null +++ b/src/dionysus/dionysus/distances.h @@ -0,0 +1,93 @@ +#ifndef DIONYSUS_DISTANCES_H +#define DIONYSUS_DISTANCES_H + +#include +#include + +namespace dionysus +{ + +/** + * Class: ExplicitDistances + * Stores the pairwise distances of Distances_ instance passed at construction. + * It's a protypical Distances template argument for the Rips complex. + */ +template +class ExplicitDistances +{ + public: + typedef Distances_ Distances; + typedef size_t IndexType; + typedef typename Distances::DistanceType DistanceType; + + ExplicitDistances(IndexType size): + size_(size), + distances_(size*(size + 1)/2 + size) {} + ExplicitDistances(const Distances& distances); + + DistanceType operator()(IndexType a, IndexType b) const; + DistanceType& operator()(IndexType a, IndexType b); + + size_t size() const { return size_; } + IndexType begin() const { return 0; } + IndexType end() const { return size(); } + + private: + std::vector distances_; + size_t size_; +}; + + +/** + * Class: PairwiseDistances + * Given a Container_ of points and a Distance_, it computes distances between elements + * in the container (given as instances of Index_ defaulted to unsigned) using the Distance_ functor. + * + * Container_ is assumed to be an std::vector. That simplifies a number of things. + */ +template +class PairwiseDistances +{ + public: + typedef Container_ Container; + typedef Distance_ Distance; + typedef Index_ IndexType; + typedef typename Distance::result_type DistanceType; + + + PairwiseDistances(const Container& container, + const Distance& distance = Distance()): + container_(container), distance_(distance) {} + + DistanceType operator()(IndexType a, IndexType b) const { return distance_(container_[a], container_[b]); } + + size_t size() const { return container_.size(); } + IndexType begin() const { return 0; } + IndexType end() const { return size(); } + + private: + const Container& container_; + Distance distance_; +}; + +template +struct L2Distance +{ + typedef Point_ Point; + typedef decltype(Point()[0] + 0) result_type; + + result_type operator()(const Point& p1, const Point& p2) const + { + result_type sum = 0; + for (size_t i = 0; i < p1.size(); ++i) + sum += (p1[i] - p2[i])*(p1[i] - p2[i]); + + return sqrt(sum); + } +}; + +} + +#include "distances.hpp" + +#endif // DIONYSUS_DISTANCES_H diff --git a/src/dionysus/dionysus/distances.hpp b/src/dionysus/dionysus/distances.hpp new file mode 100755 index 0000000..9b1f20a --- /dev/null +++ b/src/dionysus/dionysus/distances.hpp @@ -0,0 +1,30 @@ +template +dionysus::ExplicitDistances:: +ExplicitDistances(const Distances& distances): + size_(distances.size()), distances_((distances.size() * (distances.size() + 1))/2) +{ + IndexType i = 0; + for (typename Distances::IndexType a = distances.begin(); a != distances.end(); ++a) + for (typename Distances::IndexType b = a; b != distances.end(); ++b) + { + distances_[i++] = distances(a,b); + } +} + +template +typename dionysus::ExplicitDistances::DistanceType +dionysus::ExplicitDistances:: +operator()(IndexType a, IndexType b) const +{ + if (a > b) std::swap(a,b); + return distances_[a*size_ - ((a*(a-1))/2) + (b-a)]; +} + +template +typename dionysus::ExplicitDistances::DistanceType& +dionysus::ExplicitDistances:: +operator()(IndexType a, IndexType b) +{ + if (a > b) std::swap(a,b); + return distances_[a*size_ - ((a*(a-1))/2) + (b-a)]; +} diff --git a/src/dionysus/dionysus/dlog/progress.h b/src/dionysus/dionysus/dlog/progress.h new file mode 100755 index 0000000..12bf86a --- /dev/null +++ b/src/dionysus/dionysus/dlog/progress.h @@ -0,0 +1,57 @@ +#ifndef DLOG_PROGRESS_H +#define DLOG_PROGRESS_H + +#include +#include +#include +#include + +namespace dlog +{ + +struct progress +{ + progress(size_t total): + current_(0), total_(total) { show_progress(); } + + progress& operator++() { current_++; if (current_ * 100 / total_ > (current_ - 1) * 100 / total_) show_progress(); check_done(); return *this; } + progress& operator=(size_t cur) { current_ = cur; show_progress(); check_done(); return *this; } + progress& operator()(const std::string& s) { message_ = s; show_progress(); check_done(); return *this; } + template + progress& operator()(const T& x) { std::ostringstream oss; oss << x; return (*this)(oss.str()); } + + inline void show_progress() const; + void check_done() const { if (current_ >= total_) std::cout << "\n" << std::flush; } + + private: + size_t current_, total_; + std::string message_; +}; + +} + +void +dlog::progress:: +show_progress() const +{ + int barWidth = 70; + + std::cout << "["; + int pos = barWidth * current_ / total_; + for (int i = 0; i < barWidth; ++i) + { + if (i < pos) + std::cout << "="; + else if (i == pos) + std::cout << ">"; + else + std::cout << " "; + } + std::cout << "] " << std::setw(3) << current_ * 100 / total_ << "%"; + if (!message_.empty()) + std::cout << " (" << message_ << ")"; + std::cout << "\r"; + std::cout.flush(); +} + +#endif diff --git a/src/dionysus/dionysus/fields/q.h b/src/dionysus/dionysus/fields/q.h new file mode 100755 index 0000000..6cbf7ee --- /dev/null +++ b/src/dionysus/dionysus/fields/q.h @@ -0,0 +1,74 @@ +#ifndef DIONYSUS_Q_H +#define DIONYSUS_Q_H + +#include + +// TODO: eventually need to be able to adaptively switch to arbitrary precision arithmetic + +namespace dionysus +{ + +template +class Q +{ + public: + // typedef stuff. Renames Element_ to BaseElement + using BaseElement = Element_; + // Elements of the field Q should have a numerator and denominator. Equality should be determined by the relationship {ac ~ bd} given elements {a,b} and {c,d} + struct Element + { + // An element of this field should have a numerator and a denominator. This field is the field of rational numbers. + BaseElement numerator, denominator; + // Redefine the equals operator for elements. numerators for each should be equal and denominators should be equal? + bool operator==(Element o) const { return numerator == o.numerator && denominator == o.denominator; } + // Ask Dave + bool operator!=(Element o) const { return !((*this) == o); } + // ask Dave + friend + std::ostream& operator<<(std::ostream& out, Element e) { out << e.numerator << '/' << e.denominator; return out; } + }; + // identity is 1/1, zero is 0/1, given a long we make denominator 0 + Element id() const { return { 1,1 }; } + Element zero() const { return { 0,1 }; } + Element init(BaseElement a) const { return { a,1 }; } + + Element neg(Element a) const { return { -a.numerator, a.denominator }; } + Element add(Element a, Element b) const { Element x { a.numerator*b.denominator + b.numerator*a.denominator, a.denominator*b.denominator }; normalize(x); return x; } + + Element inv(Element a) const { return { a.denominator, a.numerator }; } + Element mul(Element a, Element b) const { Element x { a.numerator*b.numerator, a.denominator*b.denominator }; normalize(x); return x; } + Element div(Element a, Element b) const { return mul(a, inv(b)); } + + bool is_zero(Element a) const { return a.numerator == 0; } + + BaseElement numerator(const Element& x) const { return x.numerator; } + BaseElement denominator(const Element& x) const { return x.denominator; } + + static void normalize(Element& x) + { + BaseElement q = gcd(abs(x.numerator), abs(x.denominator)); + x.numerator /= q; + x.denominator /= q; + if (x.denominator < 0) + { + x.numerator = -x.numerator; + x.denominator = -x.denominator; + } + } + + static BaseElement abs(BaseElement x) { if (x < 0) return -x; return x; } + static BaseElement gcd(BaseElement a, BaseElement b) { + if (b < a) + return gcd(b,a); + while (a != 0) { + b %= a; std::swap(a,b); + } + return b; + } + + static bool is_prime(BaseElement x) { return false; } // Ok, since is_prime is only used as a shortcut +}; + +} + +#endif diff --git a/src/dionysus/dionysus/fields/z2.h b/src/dionysus/dionysus/fields/z2.h new file mode 100755 index 0000000..d73dbdf --- /dev/null +++ b/src/dionysus/dionysus/fields/z2.h @@ -0,0 +1,34 @@ +#ifndef DIONYSUS_Z2_H +#define DIONYSUS_Z2_H + +namespace dionysus +{ + +class Z2Field +{ + public: + typedef short Element; + + Z2Field() {} // this is a constructor + // this is a function that returns a short 1 + static Element id() { return 1; } + // this is a function that returns 0 + static Element zero() { return 0; } + // this init function returns + static Element init(int a) { return (a % 2 + 2) % 2; } + // turn a 0 to a 1 or vice versa + Element neg(Element a) const { return 2 - a; } + // add elements in binary + Element add(Element a, Element b) const { return (a+b) % 2; } + Element inv(Element a) const { return a; } + Element mul(Element a, Element b) const { return a*b; } + // This is strange + Element div(Element a, Element b) const { return a; } + + bool is_zero(Element a) const { return a == 0; } +}; + +} + +#endif + diff --git a/src/dionysus/dionysus/fields/zp.h b/src/dionysus/dionysus/fields/zp.h new file mode 100755 index 0000000..9fb761a --- /dev/null +++ b/src/dionysus/dionysus/fields/zp.h @@ -0,0 +1,60 @@ +#ifndef DIONYSUS_ZP_H +#define DIONYSUS_ZP_H + +#include + +namespace dionysus +{ + +template +class ZpField +{ + public: + typedef Element_ Element; + + ZpField(Element p); + ZpField(const ZpField& other) = default; + ZpField(ZpField&& other) = default; + + Element id() const { return 1; } + Element zero() const { return 0; } + // adding a prime number so that when modded the result is positive + Element init(int a) const { return (a % p_ + p_) % p_; } + + Element neg(Element a) const { return p_ - a; } + Element add(Element a, Element b) const { return (a+b) % p_; } + + Element inv(Element a) const { + while (a < 0) a += p_; + return inverses_[a]; + } + + Element mul(Element a, Element b) const { return (a*b) % p_; } + Element div(Element a, Element b) const { return mul(a, inv(b)); } + + bool is_zero(Element a) const { return (a % p_) == 0; } + + Element prime() const { return p_; } + + private: + Element p_; + std::vector inverses_; +}; + +template +ZpField::ZpField(Element p): + p_(p), // constructor for setting p_ + inverses_(p_) // constructor that sets length of vector + { + for (Element i = 1; i < p_; ++i) + for (Element j = 1; j < p_; ++j) + if (mul(i,j) == 1) + { + inverses_[i] = j; + break; + } + } + +} + +#endif diff --git a/src/dionysus/dionysus/filtration.h b/src/dionysus/dionysus/filtration.h new file mode 100755 index 0000000..98d8501 --- /dev/null +++ b/src/dionysus/dionysus/filtration.h @@ -0,0 +1,123 @@ +#ifndef DIONYSUS_FILTRATION_H +#define DIONYSUS_FILTRATION_H + +#include + +#include +#include +#include +#include + +namespace b = boost; +namespace bmi = boost::multi_index; + +namespace dionysus +{ + +// Filtration stores a filtered cell complex as boost::multi_index_container<...>. +// It allows for bidirectional translation between a cell and its index. +template>, + bool checked_index = false> +class Filtration +{ + public: + struct order {}; + + typedef Cell_ Cell; + typedef CellLookupIndex_ CellLookupIndex; + + typedef b::multi_index_container> + >> Container; + typedef typename Container::value_type value_type; + + typedef typename Container::template nth_index<0>::type Complex; + typedef typename Container::template nth_index<1>::type Order; + typedef typename Order::const_iterator OrderConstIterator; + typedef typename Order::iterator OrderIterator; + + + public: + Filtration() = default; + Filtration(Filtration&& other) = default; + Filtration& operator=(Filtration&& other) = default; + + Filtration(const std::initializer_list& cells): + Filtration(std::begin(cells), std::end(cells)) {} + + template + Filtration(Iterator bg, Iterator end): + cells_(bg, end) {} + + template + Filtration(const CellRange& cells): + Filtration(std::begin(cells), std::end(cells)) {} + + // Lookup + const Cell& operator[](size_t i) const { return cells_.template get()[i]; } + OrderConstIterator iterator(const Cell& s) const { return bmi::project(cells_, cells_.find(s)); } + size_t index(const Cell& s) const; + bool contains(const Cell& s) const { return cells_.find(s) != cells_.end(); } + + void push_back(const Cell& s) { cells_.template get().push_back(s); } + void push_back(Cell&& s) { cells_.template get().push_back(s); } + + void replace(size_t i, const Cell& s) { cells_.template get().replace(begin() + i, s); } + + // return index of the cell, adding it, if necessary + size_t add(const Cell& s) { size_t i = (iterator(s) - begin()); if (i == size()) emplace_back(s); return i; } + size_t add(Cell&& s) { size_t i = (iterator(s) - begin()); if (i == size()) emplace_back(std::move(s)); return i; } + + template + void emplace_back(Args&&... args) { cells_.template get().emplace_back(std::forward(args)...); } + + template> + void sort(const Cmp& cmp = Cmp()) { cells_.template get().sort(cmp); } + + void rearrange(const std::vector& indices); + + OrderConstIterator begin() const { return cells_.template get().begin(); } + OrderConstIterator end() const { return cells_.template get().end(); } + OrderIterator begin() { return cells_.template get().begin(); } + OrderIterator end() { return cells_.template get().end(); } + size_t size() const { return cells_.size(); } + void clear() { return Container().swap(cells_); } + + Cell& back() { return const_cast(cells_.template get().back()); } + const Cell& back() const { return cells_.template get().back(); } + + private: + Container cells_; +}; + +} + +template +size_t +dionysus::Filtration:: +index(const Cell& s) const +{ + auto it = iterator(s); + if (checked_index && it == end()) + { + std::ostringstream oss; + oss << "Trying to access non-existent cell: " << s; + throw std::runtime_error(oss.str()); + } + return it - begin(); +} + +template +void +dionysus::Filtration:: +rearrange(const std::vector& indices) +{ + std::vector> references; references.reserve(indices.size()); + for (size_t i : indices) + references.push_back(std::cref((*this)[i])); + cells_.template get().rearrange(references.begin()); +} + + +#endif diff --git a/src/dionysus/dionysus/omni-field-persistence.h b/src/dionysus/dionysus/omni-field-persistence.h new file mode 100755 index 0000000..afec237 --- /dev/null +++ b/src/dionysus/dionysus/omni-field-persistence.h @@ -0,0 +1,135 @@ +#ifndef DIONYSUS_OMNI_FIELD_REDUCTION_H +#define DIONYSUS_OMNI_FIELD_REDUCTION_H + +#include +#include + +#include "reduction.h" // for unpaired +#include "fields/q.h" +#include "fields/zp.h" +#include "chain.h" + +namespace dionysus +{ + +template, class Q_ = ::dionysus::Q<>, class Zp_ = ::dionysus::ZpField> +class OmniFieldPersistence +{ + public: + using Index = Index_; + using Q = Q_; + using Field = Q; + using Comparison = Comparison_; + + using BaseElement = typename Q::BaseElement; + using Zp = Zp_; + using Zps = std::unordered_map; + + using QElement = typename Q::Element; + using QEntry = ChainEntry; + using QChain = std::vector; + + using ZpElement = typename Zp::Element; + using ZpEntry = ChainEntry; + using ZpChain = std::vector; + + using QChains = std::vector; + using ZpChains = std::unordered_map>; + + using QLows = std::unordered_map; + using ZpLows = std::unordered_map>; + + using QPairs = std::vector; + using ZpPairs = std::unordered_map>; + + using Factors = std::vector; + + const Field& field() const { return q_; } + + void sort(QChain& c) { std::sort(c.begin(), c.end(), + [this](const QEntry& e1, const QEntry& e2) + { return this->cmp_(e1.index(), e2.index()); }); } + + template + void add(const ChainRange& chain) { return add(QChain(std::begin(chain), std::end(chain))); } + void add(QChain&& chain); + + void reserve(size_t s) { q_chains_.reserve(s); q_pairs_.reserve(s); } + size_t size() const { return q_pairs_.size(); } + + void reduce(ZpChain& zp_chain, BaseElement p); + ZpChain convert(const QChain& c, const Zp& field) const; + bool special(Index i, BaseElement p) const { auto it = zp_chains_.find(i); if (it == zp_chains_.end()) return false; if (it->second.find(p) == it->second.end()) return false; return true; } + + const Zp& zp(BaseElement p) const { auto it = zps_.find(p); if (it != zps_.end()) return it->second; return zps_.emplace(p, Zp(p)).first->second; } + + static Factors factor(BaseElement x); + + const QChains& q_chains() const { return q_chains_; } + const ZpChains& zp_chains() const { return zp_chains_; } + + // This is a bit of a hack; it takes advantage of the fact that zp(p) + // generates field on-demand and memoizes them. So there is an entry in + // zps_ only if something special happened over the prime. + Factors primes() const { Factors result; result.reserve(zps_.size()); for (auto& x : zps_) result.push_back(x.first); return result; } + + // TODO: no skip support for now + bool skip(Index) const { return false; } + void add_skip() {} + void set_skip(Index, bool flag = true) {} + + Index pair(Index i, BaseElement p) const; + void set_pair(Index i, Index j); + void set_pair(Index i, Index j, BaseElement p); + static const Index unpaired() { return Reduction::unpaired; } + + private: + QChains q_chains_; + ZpChains zp_chains_; + + QLows q_lows_; + ZpLows zp_lows_; + + QPairs q_pairs_; + ZpPairs zp_pairs_; + + Q q_; + mutable Zps zps_; + + Comparison cmp_; +}; + +// Make OmniFieldPersistence act like a ReducedMatrix (e.g., for the purpose of constructing a persistence diagram) +template +struct PrimeAdapter +{ + using Persistence = OmniFieldPersistence; + using Prime = typename Persistence::BaseElement; + using Index = typename Persistence::Index; + + PrimeAdapter(const Persistence& persistence, Prime p): + persistence_(persistence), p_(p) {} + + bool skip(Index i) const { return persistence_.skip(i); } + + size_t size() const { return persistence_.size(); } + Index pair(Index i) const { return persistence_.pair(i, p_); } + static const Index unpaired() { return Persistence::unpaired(); } + + const Persistence& persistence_; + Prime p_; +}; + +template +PrimeAdapter +prime_adapter(const OmniFieldPersistence& persistence, + typename PrimeAdapter::Prime p) +{ + return PrimeAdapter(persistence, p); +} + +} // dionysus + +#include "omni-field-persistence.hpp" + +#endif diff --git a/src/dionysus/dionysus/omni-field-persistence.hpp b/src/dionysus/dionysus/omni-field-persistence.hpp new file mode 100755 index 0000000..68d5fbe --- /dev/null +++ b/src/dionysus/dionysus/omni-field-persistence.hpp @@ -0,0 +1,250 @@ +template +void +dionysus::OmniFieldPersistence:: +add(QChain&& chain) +{ + sort(chain); + + q_chains_.emplace_back(std::move(chain)); + q_pairs_.emplace_back(unpaired()); + Index i = q_chains_.size() - 1; + + QChain& c = q_chains_.back(); + + auto reduce = [this,&c,i](BaseElement p) + { + auto zp_chain = convert(c, zp(p)); + + this->reduce(zp_chain, p); + + if (!zp_chain.empty()) + { + auto l = zp_chain.back().index(); + zp_lows_[l].emplace(p,i); + set_pair(l,i,p); + } + + zp_chains_[i].emplace(p, std::move(zp_chain)); // empty chain is still a valid indicator that we don't need to bother with this field + }; + + // reduce + auto entry_cmp = [this](const QEntry& e1, const QEntry& e2) { return this->cmp_(e1.index(), e2.index()); }; + while (!c.empty()) + { + auto& low = c.back(); + + auto e = low.element(); + auto l = low.index(); + assert(!q_.is_zero(e)); + if (e != q_.id()) + { + auto factors = factor(q_.numerator(e)); + for (auto p : factors) + { + if (!special(i, p)) // there is already a dedicated column over p + reduce(p); + } + } + + auto it_zp = zp_lows_.find(l); + if (it_zp != zp_lows_.end()) + for (auto& x : it_zp->second) + { + auto p = x.first; + if (!special(i,p)) + reduce(p); + } + + auto it_q = q_lows_.find(l); + if (it_q != q_lows_.end()) + { + Index j = it_q->second; + + // add the primes from j to i + auto it_zp = zp_chains_.find(j); + if (it_zp != zp_chains_.end()) + for (auto& x : it_zp->second) + { + auto p = x.first; + if (!special(i,p)) + reduce(p); + } + + // reduce over Q + auto j_chain = q_chains_[j]; + auto j_e = j_chain.back().element(); + + auto m = q_.neg(q_.div(e,j_e)); + Chain::addto(c, m, j_chain, q_, entry_cmp); + assert(c.empty() || !q_.is_zero(c.back().element())); + } else + { + q_lows_.emplace(l,i); + set_pair(l,i); + break; + } + } +} + +template +void +dionysus::OmniFieldPersistence:: +reduce(ZpChain& zp_chain, BaseElement p) +{ + auto& field = zp(p); + + auto entry_cmp = [this](const ZpEntry& e1, const ZpEntry& e2) { return this->cmp_(e1.index(), e2.index()); }; + + while (!zp_chain.empty()) + { + auto& low = zp_chain.back(); + auto j = low.index(); + + auto it = zp_lows_.find(j); + if (it != zp_lows_.end()) + { + auto it2 = it->second.find(p); + if (it2 != it->second.end()) + { + const ZpChain& co = zp_chains_[it2->second][p]; + + auto m = field.neg(field.div(low.element(), co.back().element())); + assert(m < p); + Chain::addto(zp_chain, m, co, field, entry_cmp); + continue; + } + } + + auto qit = q_lows_.find(j); + if (qit == q_lows_.end() || special(qit->second, p)) // no valid pivot over Q + return; + + // TODO: this could be optimized (add and convert on the fly) + auto& q_chain = q_chains_[qit->second]; + assert(q_chain.empty() || !q_.is_zero(q_chain.back().element())); + + auto co = convert(q_chain, field); + auto m = field.neg(field.div(low.element(), co.back().element())); + Chain::addto(zp_chain, m, co, field, entry_cmp); + + assert(!zp_chain.empty() || zp_chain.back().index() != j); + } +} + +template +typename dionysus::OmniFieldPersistence::ZpChain +dionysus::OmniFieldPersistence:: +convert(const QChain& c, const Zp& field) const +{ + ZpChain result; + result.reserve(c.size()); + auto p = field.prime(); + for (auto& x : c) + { + auto num = q_.numerator(x.element()) % p; + if (num != 0) + { + while (num < 0) num += p; + auto denom = q_.denominator(x.element()) % p; + while (denom < 0) denom += p; + assert(denom % p != 0); + result.emplace_back(field.div(num, denom), x.index()); + } + } + return result; +} + + +template +typename dionysus::OmniFieldPersistence::Factors +dionysus::OmniFieldPersistence:: +factor(BaseElement x) +{ + if (x < 0) + x = -x; + Factors result; + + if (Q::is_prime(x)) + { + result.push_back(x); + return result; + } + + BaseElement p { 2 }; + while (p*p <= x) + { + if (x % p == 0) + { + result.push_back(p); + do { x /= p; } while (x % p == 0); + if (Q::is_prime(x)) + { + result.push_back(x); + break; + } + } + ++p; + } + if (x > 1) + result.push_back(x); + + return result; +} + +template +typename dionysus::OmniFieldPersistence::Index +dionysus::OmniFieldPersistence:: +pair(Index i, BaseElement p) const +{ + if (p == 1) + return q_pairs_[i]; + else + { + auto it = zp_pairs_.find(p); + if (it == zp_pairs_.end()) + return q_pairs_[i]; + else + { + auto pit = it->second.find(i); + if (pit == it->second.end()) + return q_pairs_[i]; + else + return pit->second; + } + } +} + +template +void +dionysus::OmniFieldPersistence:: +set_pair(Index i, Index j, BaseElement p) +{ + auto& pairs = zp_pairs_[p]; + pairs[i] = j; + pairs[j] = i; +} + +template +void +dionysus::OmniFieldPersistence:: +set_pair(Index i, Index j) +{ + q_pairs_[i] = j; + q_pairs_[j] = i; + + auto it = zp_chains_.find(j); + if (it == zp_chains_.end()) + return; + + auto& chains = it->second; + for (auto& x : chains) + { + auto p = x.first; + auto& chain = x.second; + if (chain.empty()) + { + zp_pairs_[p][j] = unpaired(); + zp_pairs_[p][i] = unpaired(); + } + } +} diff --git a/src/dionysus/dionysus/ordinary-persistence.h b/src/dionysus/dionysus/ordinary-persistence.h new file mode 100755 index 0000000..5f26bd2 --- /dev/null +++ b/src/dionysus/dionysus/ordinary-persistence.h @@ -0,0 +1,64 @@ +#ifndef DIONYSUS_ORDINARY_PERSISTENCE_H +#define DIONYSUS_ORDINARY_PERSISTENCE_H + +#include "reduced-matrix.h" + +namespace dionysus +{ + +/* Move this into a ReducedMatrix class */ + +// Ordinary D -> R reduction +template, + template class... Visitors> +using OrdinaryPersistence = ReducedMatrix; + +// No negative optimization +template> +struct NoNegative +{ + template + struct Visitor: public EmptyVisitor + { + template + void chain_initialized(Self* matrix, Chain& c) + { + for (auto cur = std::begin(c); cur != std::end(c); ++cur) + { + Index i = cur->index(); + Index p = matrix->pair(i); + if (!(p == Self::unpaired() || (*matrix)[i].empty())) + c.erase(cur--); + } + } + }; + + template + using V2 = EmptyVisitor; +}; + +template, + template class... Visitors> +using OrdinaryPersistenceNoNegative = ReducedMatrix::template Visitor, + Visitors...>; + +// TODO: add clearing optimization (possibly bake it into the code itself) + +template, + template class... Visitors> +using FastPersistence = ReducedMatrix::template Visitor, + //Clearing::template Visitor, // FIXME + Visitors...>; + + +} + +#endif diff --git a/src/dionysus/dionysus/pair-recorder.h b/src/dionysus/dionysus/pair-recorder.h new file mode 100755 index 0000000..81c066b --- /dev/null +++ b/src/dionysus/dionysus/pair-recorder.h @@ -0,0 +1,78 @@ +#ifndef DIONYSUS_PAIR_RECORDER_H +#define DIONYSUS_PAIR_RECORDER_H + +namespace dionysus +{ + +template +struct PairRecorder: public Persistence_ +{ + typedef Persistence_ Persistence; + typedef typename Persistence::Index Index; + + + using Persistence::Persistence; + + template + Index add(const ChainRange& chain) + { + Index p = Persistence::add(chain); + pairs_.push_back(p); + if (p != unpaired()) + pairs_[p] = pairs_.size() - 1; + + return p; + } + + Index pair(Index i) const { return pairs_[i]; } + + void resize(size_t s) { Persistence::resize(s); pairs_.resize(s, unpaired()); } + size_t size() const { return pairs_.size(); } + static const Index unpaired() { return Reduction::unpaired; } + + std::vector pairs_; +}; + +template +struct PairChainRecorder: public PairRecorder +{ + using Persistence = Persistence_; + using Parent = PairRecorder; + using Index = typename Persistence_::Index; + using Chain = typename Persistence_::Chain; + + using Parent::Parent; + + template + Index add(const ChainRange& chain) + { + auto p_chain = Persistence::add(chain, keep_cocycles); + Index p = std::get<0>(p_chain); + + pairs_.push_back(p); + chains_.emplace_back(); + + if (p != unpaired()) + { + pairs_[p] = pairs_.size() - 1; + chains_[p] = std::move(std::get<1>(p_chain)); + } + + return p; + } + + using Parent::unpaired; + + Index pair(Index i) const { return pairs_[i]; } + const Chain& chain(Index i) const { return chains_[i]; } // chain that dies at i + void resize(size_t s) { Parent::resize(s); chains_.resize(s); } + + std::vector chains_; + using Parent::pairs_; + + bool keep_cocycles = true; +}; + +} + +#endif diff --git a/src/dionysus/dionysus/reduced-matrix.h b/src/dionysus/dionysus/reduced-matrix.h new file mode 100755 index 0000000..39ced5d --- /dev/null +++ b/src/dionysus/dionysus/reduced-matrix.h @@ -0,0 +1,166 @@ +#ifndef DIONYSUS_REDUCED_MATRIX_H +#define DIONYSUS_REDUCED_MATRIX_H + +#include +#include + +#include "chain.h" +#include "reduction.h" + +namespace dionysus +{ + +template, template class... Visitors> +class ReducedMatrix +{ + public: + typedef ReducedMatrix Self; + + typedef Field_ Field; + typedef Index_ Index; + typedef Comparison_ Comparison; + + typedef std::tuple...> VisitorsTuple; + template + using Visitor = std::tuple_element; + + typedef typename Field::Element FieldElement; + typedef ChainEntry Entry; + typedef std::vector Chain; + + typedef std::vector Chains; + typedef std::vector Indices; + typedef std::vector SkipFlags; + + public: + ReducedMatrix(const Field& field): + field_(field) {} + + ReducedMatrix(const Field& field, + const Comparison& cmp, + const Visitors&... visitors): + field_(field), + cmp_(cmp), + visitors_(visitors...) {} + + ReducedMatrix(Field&& field, + Comparison&& cmp, + Visitors&&... visitors): + field_(std::move(field)), + cmp_(std::move(cmp)), + visitors_(visitors...) {} + + template class... OtherVisitors> + ReducedMatrix(ReducedMatrix&& other): + field_(other.field_), + cmp_(other.cmp_), + reduced_(std::move(other.reduced_)), + pairs_(std::move(other.pairs_)) {} + + template + Index add(const ChainRange& chain) { return add(Chain(std::begin(chain), std::end(chain))); } + Index add(Chain&& chain); + + template + void set(Index i, const ChainRange& chain) { return set(i, Chain(std::begin(chain), std::end(chain))); } + void set(Index i, Chain&& chain); + + Index reduce(Index i); + Index reduce(Chain& c) { return reduce(c, reduced_, pairs_); } + template + Index reduce(Chain& c, const ChainsLookup& chains, const LowLookup& low); + + Index reduce_upto(Index i); // TODO + + size_t size() const { return pairs_.size(); } + void clear() { Chains().swap(reduced_); Indices().swap(pairs_); } + + void sort(Chain& c) { std::sort(c.begin(), c.end(), [this](const Entry& e1, const Entry& e2) { return this->cmp_(e1.index(), e2.index()); }); } + + const Chain& operator[](Index i) const { return reduced_[i]; } + Index pair(Index i) const { return pairs_[i]; } + void set_pair(Index i, Index j) { pairs_[i] = j; pairs_[j] = i; } + + Chain& column(Index i) { return reduced_[i]; } + + bool skip(Index i) const { return skip_[i]; } + void add_skip(); + void set_skip(Index i, bool flag = true) { skip_[i] = flag; } + + const Field& field() const { return field_; } + const Comparison& cmp() const { return cmp_; } + void reserve(size_t s) { reduced_.reserve(s); pairs_.reserve(s); } + void resize(size_t s); + + const Chains& columns() const { return reduced_; } + + template + Visitor& visitor() { return std::get(visitors_); } + + static const Index unpaired() { return Reduction::unpaired; } + + private: + template class... Vs> + friend class ReducedMatrix; // let's all be friends + + public: + // Visitors::chain_initialized(c) + template + typename std::enable_if::type + visitors_chain_initialized(Chain& c) {} + + template + typename std::enable_if::type + visitors_chain_initialized(Chain& c) { std::get(visitors_).chain_initialized(this, c); visitors_chain_initialized(c); } + + // Visitors::addto(m, cl) + template + typename std::enable_if::type + visitors_addto(FieldElement m, Index cl) {} + + template + typename std::enable_if::type + visitors_addto(FieldElement m, Index cl) { std::get(visitors_).addto(this, m, cl); visitors_addto(m, cl); } + + // Visitors::reduction_finished(m, cl) + template + typename std::enable_if::type + visitors_reduction_finished() {} + + template + typename std::enable_if::type + visitors_reduction_finished() { std::get(visitors_).reduction_finished(this); visitors_reduction_finished(); } + + private: + Field field_; + Comparison cmp_; + Chains reduced_; // matrix R + Indices pairs_; + SkipFlags skip_; // indicates whether the column should be skipped (e.g., for relative homology) + VisitorsTuple visitors_; +}; + +/* Visitors */ + +// The prototypical visitor. Others may (and probably should) inherit from it. +template +struct EmptyVisitor +{ + EmptyVisitor() = default; + + template + EmptyVisitor(const EmptyVisitor&) {} + + + template + void chain_initialized(Self*, Chain& c) {} + + void addto(Self*, typename Field::Element m, Index cl) {} + void reduction_finished(Self*) {} +}; + +} + +#include "reduced-matrix.hpp" + +#endif diff --git a/src/dionysus/dionysus/reduced-matrix.hpp b/src/dionysus/dionysus/reduced-matrix.hpp new file mode 100755 index 0000000..3e4aca8 --- /dev/null +++ b/src/dionysus/dionysus/reduced-matrix.hpp @@ -0,0 +1,78 @@ +template class... V> +void +dionysus::ReducedMatrix:: +resize(size_t s) +{ + reduced_.resize(s); + pairs_.resize(s, unpaired()); + skip_.resize(s, false); +} + +template class... V> +typename dionysus::ReducedMatrix::Index +dionysus::ReducedMatrix:: +add(Chain&& chain) +{ + // TODO: skip the computation entirely if we already know this is positive (in case of the clearing optimization) + Index i = pairs_.size(); + pairs_.emplace_back(unpaired()); + reduced_.emplace_back(); + skip_.push_back(false); + + set(i, std::move(chain)); + + return reduce(i); +} + +template class... V> +void +dionysus::ReducedMatrix:: +add_skip() +{ + pairs_.emplace_back(unpaired()); + reduced_.emplace_back(); + skip_.push_back(true); +} + +template class... V> +void +dionysus::ReducedMatrix:: +set(Index i, Chain&& c) +{ + sort(c); + visitors_chain_initialized(c); + reduced_[i] = std::move(c); +} + +template class... V> +typename dionysus::ReducedMatrix::Index +dionysus::ReducedMatrix:: +reduce(Index i) +{ + Chain& c = column(i); + Index pair = reduce(c); + + if (pair != unpaired()) + pairs_[pair] = i; + + pairs_[i] = pair; + visitors_reduction_finished<>(); + + return pair; +} + +template class... V> +template +typename dionysus::ReducedMatrix::Index +dionysus::ReducedMatrix:: +reduce( Chain& c, + const ChainsLookup& chains, + const LowLookup& lows) +{ + auto entry_cmp = [this](const Entry& e1, const Entry& e2) { return this->cmp_(e1.index(), e2.index()); }; + return Reduction::reduce(c, chains, lows, field_, + [this](FieldElement m, Index cl) + { this->visitors_addto<>(m, cl); }, + entry_cmp); +} diff --git a/src/dionysus/dionysus/reduction.h b/src/dionysus/dionysus/reduction.h new file mode 100755 index 0000000..152f1bb --- /dev/null +++ b/src/dionysus/dionysus/reduction.h @@ -0,0 +1,107 @@ +#ifndef DIONYSUS_REDUCTION_H +#define DIONYSUS_REDUCTION_H + +#include +#include +#include +#include "chain.h" + +namespace dionysus +{ + +namespace detail +{ + +template +struct Unpaired +{ static constexpr Index value() { return std::numeric_limits::max(); } }; + +} + +template +struct Reduction +{ + typedef Index_ Index; + + template + using AddtoVisitor = std::function; + + template + struct CallToSub; + + static const Index unpaired; + + template> + static + Index reduce(Chain1& c, + const ChainsLookup& chains, + const LowLookup& lows, + const Field& field, + const AddtoVisitor& visitor = [](typename Field::Element, Index) {}, + const Comparison& cmp = Comparison()) + { + typedef typename Field::Element FieldElement; + + while (!c.empty()) + { + //auto& low = c.back(); + auto& low = *(std::prev(c.end())); + Index l = low.index(); + Index cl = lows(l); + if (cl == unpaired) + return l; + else + { + // Reduce further + auto& co = chains(cl); + auto& co_low = co.back(); + FieldElement m = field.neg(field.div(low.element(), co_low.element())); + // c += m*co + Chain::addto(c, m, co, field, cmp); + visitor(m, cl); + } + } + return unpaired; + } + + template> + static + Index reduce(Chain1& c, + const std::vector& chains, + const std::vector& lows, + const Field& field, + const AddtoVisitor& visitor = [](typename Field::Element, Index) {}, + const Comparison& cmp = Comparison()) + { + return reduce(c, + CallToSub(chains), + CallToSub(lows), + field, visitor, cmp); + } + + // This is a work-around a bug in GCC (should really be a lambda function) + template + struct CallToSub + { + CallToSub(const std::vector& items_): + items(items_) {} + const Item& operator()(Index i) const { return items[i]; } + const std::vector& items; + }; +}; + + +template +const Index +Reduction::unpaired = detail::Unpaired::value(); + +} + +#endif diff --git a/src/dionysus/dionysus/relative-homology-zigzag.h b/src/dionysus/dionysus/relative-homology-zigzag.h new file mode 100755 index 0000000..167a327 --- /dev/null +++ b/src/dionysus/dionysus/relative-homology-zigzag.h @@ -0,0 +1,84 @@ +#ifndef RELATIVE_HOMOLOGY_ZIGZAG_H +#define RELATIVE_HOMOLOGY_ZIGZAG_H + +#include +#include + +#include "zigzag-persistence.h" + +namespace dionysus +{ + +namespace ba = boost::adaptors; + +template> +class RelativeHomologyZigzag +{ + public: + typedef Field_ Field; + typedef Index_ Index; + typedef Comparison_ Comparison; + + typedef ZigzagPersistence ZZP; + typedef typename ZZP::IndexChain IndexChain; + typedef typename ZZP::FieldElement FieldElement; + typedef typename IndexChain::value_type ChainEntry; + + + typedef Comparison Cmp; + + RelativeHomologyZigzag(const Field& field, + const Comparison& cmp = Comparison()): + zzp_(field, cmp) + { + zzp_.add( IndexChain() ); // vertex w + ++zzp_op_; + ++zzp_cell_; + } + + template + void add_both(const ChainRange& chain); + + void remove_both(Index cell); + + // index of the absolute cell; chain = its boundary + template + Index add(Index cell, const ChainRange& chain); // add to the relative part + + Index remove(Index cell); // remove from the relative part + + const Field& field() const { return zzp_.field(); } + const Cmp& cmp() const { return zzp_.cmp(); } + + size_t alive_size() const { return zzp_.alive_size() - 1; } // -1 for the cone vertex + + static + const Index unpaired() { return ZZP::unpaired(); } + + private: + template + IndexChain relative_chain(Index cell, const ChainRange& chain) const; + + template + IndexChain absolute_chain(const ChainRange& chain) const; + + Index abs_index(Index idx) const { return absolute_.left.find(idx)->second; } + Index rel_index(Index idx) const { return relative_.left.find(idx)->second; } + Index decode_pair(Index pair); + + private: + ZZP zzp_; // underlying (cone) implementation + boost::bimap absolute_; // bimap between our cells and zzp absolute cells + boost::bimap relative_; // bimap between our cells and zzp relative cells + std::unordered_map op_map_; // map from zzp_op to our op + Index op_ = 0, + zzp_op_ = 0, + cell_ = 0, + zzp_cell_ = 0; +}; + +} + +#include "relative-homology-zigzag.hpp" + +#endif diff --git a/src/dionysus/dionysus/relative-homology-zigzag.hpp b/src/dionysus/dionysus/relative-homology-zigzag.hpp new file mode 100755 index 0000000..4998071 --- /dev/null +++ b/src/dionysus/dionysus/relative-homology-zigzag.hpp @@ -0,0 +1,122 @@ +template +template +void +dionysus::RelativeHomologyZigzag:: +add_both(const ChainRange& chain) +{ + zzp_.add(absolute_chain(chain)); + op_map_.insert( { zzp_op_++, op_ } ); + absolute_.left.insert( { cell_, zzp_cell_++ } ); + + zzp_.add(relative_chain(cell_, chain)); + op_map_.insert( { zzp_op_++, op_ } ); + relative_.left.insert( { cell_, zzp_cell_++ } ); + + cell_++; + op_++; +} + +template +void +dionysus::RelativeHomologyZigzag:: +remove_both(Index cell) +{ + Index abs_cell = absolute_.left.find(cell)->second; + Index rel_cell = relative_.left.find(cell)->second; + + zzp_.remove(rel_cell); + zzp_.remove(abs_cell); + + absolute_.left.erase(cell); + relative_.left.erase(cell); + + op_map_.insert( { zzp_op_++, op_ } ); + op_map_.insert( { zzp_op_++, op_ } ); + + op_++; +} + +template +template +typename dionysus::RelativeHomologyZigzag::Index +dionysus::RelativeHomologyZigzag:: +add(Index cell, const ChainRange& chain) +{ + Index pair = zzp_.add(relative_chain(cell, chain)); + op_map_.insert( { zzp_op_++, op_++ } ); + relative_.left.insert( { cell, zzp_cell_++ } ); + + return decode_pair(pair); +} + + +template +typename dionysus::RelativeHomologyZigzag::Index +dionysus::RelativeHomologyZigzag:: +decode_pair(Index pair) +{ + if (pair == unpaired()) + return pair; + + Index decoded = op_map_.find(pair)->second; + op_map_.erase(pair); + return decoded; +} + +template +template +typename dionysus::RelativeHomologyZigzag::IndexChain +dionysus::RelativeHomologyZigzag:: +absolute_chain(const ChainRange& chain) const +{ + IndexChain res; + for (const auto& e : chain) + res.push_back(ChainEntry(e.element(), abs_index(e.index()))); + return res; +} + +template +template +typename dionysus::RelativeHomologyZigzag::IndexChain +dionysus::RelativeHomologyZigzag:: +relative_chain(Index cell, const ChainRange& chain) const +{ + // NB: to compute the signs correctly, + // this assumes that the cone vertex w is the last vertex in some total order + + typedef typename IndexChain::value_type ChainEntry; + + IndexChain res; + if (!chain.empty()) + { + for (const auto& e : chain) + res.push_back(ChainEntry(e.element(), rel_index(e.index()))); + + FieldElement a = field().id(); + if (chain.size() % 2 == 0) // TODO: double-check + a = field().neg(a); + res.push_back(ChainEntry(a, abs_index(cell))); // add the base space cell + } else + { + res.reserve(2); + res.push_back(ChainEntry(field().id(), abs_index(cell))); + res.push_back(ChainEntry(field().neg(field().id()), 0)); + } + return res; +} + + +template +typename dionysus::RelativeHomologyZigzag::Index +dionysus::RelativeHomologyZigzag:: +remove(Index cell) +{ + Index rel_cell = rel_index(cell); + Index pair = zzp_.remove(rel_cell); + pair = decode_pair(pair); + + op_map_.insert( { zzp_op_++, op_++ } ); + relative_.left.erase(cell); + + return pair; +} diff --git a/src/dionysus/dionysus/rips.h b/src/dionysus/dionysus/rips.h new file mode 100755 index 0000000..717ff65 --- /dev/null +++ b/src/dionysus/dionysus/rips.h @@ -0,0 +1,147 @@ +#ifndef DIONYSUS_RIPS_H +#define DIONYSUS_RIPS_H + +#include +#include + +#include + +#include "simplex.h" + +namespace dionysus +{ + +/** + * Rips class + * + * Class providing basic operations to work with Rips complexes. It implements Bron-Kerbosch algorithm, + * and provides simple wrappers for various functions. + * + * Distances_ is expected to define types IndexType and DistanceType as well as + * provide operator()(...) which given two IndexTypes should return + * the distance between them. There should be methods begin() and end() + * for iterating over IndexTypes as well as a method size(). + */ +template > +class Rips +{ + public: + typedef Distances_ Distances; + typedef typename Distances::IndexType IndexType; + typedef typename Distances::DistanceType DistanceType; + + typedef Simplex_ Simplex; + typedef typename Simplex::Vertex Vertex; // should be the same as IndexType + typedef std::vector VertexContainer; + + typedef short unsigned Dimension; + + class Evaluator; + class Comparison; + + public: + Rips(const Distances& distances): + distances_(distances) {} + + // Calls functor f on each simplex in the k-skeleton of the Rips complex + template + void generate(Dimension k, DistanceType max, const Functor& f, + Iterator candidates_begin, Iterator candidates_end) const; + + // Calls functor f on all the simplices of the Rips complex that contain the given vertex v + template + void vertex_cofaces(IndexType v, Dimension k, DistanceType max, const Functor& f, + Iterator candidates_begin, Iterator candidates_end) const; + + // Calls functor f on all the simplices of the Rips complex that contain the given edge [u,v] + template + void edge_cofaces(IndexType u, IndexType v, Dimension k, DistanceType max, const Functor& f, + Iterator candidates_begin, Iterator candidates_end) const; + + // Calls functor f on all the simplices of the Rips complex that contain the given Simplex s + // (unlike the previous methods it does not call the functor on the Simplex s itself) + template + void cofaces(const Simplex& s, Dimension k, DistanceType max, const Functor& f, + Iterator candidates_begin, Iterator candidates_end) const; + + + /* No Iterator argument means Iterator = IndexType and the range is [distances().begin(), distances().end()) */ + template + void generate(Dimension k, DistanceType max, const Functor& f) const + { generate(k, max, f, boost::make_counting_iterator(distances().begin()), boost::make_counting_iterator(distances().end())); } + + template + void vertex_cofaces(IndexType v, Dimension k, DistanceType max, const Functor& f) const + { vertex_cofaces(v, k, max, f, boost::make_counting_iterator(distances().begin()), boost::make_counting_iterator(distances().end())); } + + template + void edge_cofaces(IndexType u, IndexType v, Dimension k, DistanceType max, const Functor& f) const + { edge_cofaces(u, v, k, max, f, boost::make_counting_iterator(distances().begin()), boost::make_counting_iterator(distances().end())); } + + template + void cofaces(const Simplex& s, Dimension k, DistanceType max, const Functor& f) const + { cofaces(s, k, max, f, boost::make_counting_iterator(distances().begin()), boost::make_counting_iterator(distances().end())); } + + + const Distances& distances() const { return distances_; } + DistanceType max_distance() const; + + DistanceType distance(const Simplex& s1, const Simplex& s2) const; + + + template + static void bron_kerbosch(VertexContainer& current, + const VertexContainer& candidates, + typename VertexContainer::const_iterator excluded, + Dimension max_dim, + const NeighborTest& neighbor, + const Functor& functor, + bool check_initial = true); + + protected: + const Distances& distances_; +}; + +template +class Rips::Evaluator: public std::unary_function +{ + public: + typedef Simplex_ Simplex; + + Evaluator(const Distances& distances): + distances_(distances) {} + + DistanceType operator()(const Simplex& s) const; + + protected: + const Distances& distances_; +}; + +template +class Rips::Comparison: public std::binary_function +{ + public: + typedef Simplex_ Simplex; + + Comparison(const Distances& distances): + eval_(distances) {} + + bool operator()(const Simplex& s1, const Simplex& s2) const + { + DistanceType e1 = eval_(s1), + e2 = eval_(s2); + if (e1 == e2) + return s1.dimension() < s2.dimension(); + + return e1 < e2; + } + + protected: + Evaluator eval_; +}; + +} + +#include "rips.hpp" + +#endif // DIONYSUS_RIPS_H diff --git a/src/dionysus/dionysus/rips.hpp b/src/dionysus/dionysus/rips.hpp new file mode 100755 index 0000000..2fdda34 --- /dev/null +++ b/src/dionysus/dionysus/rips.hpp @@ -0,0 +1,162 @@ +#include +#include +#include +#include + +#include +#include + +template +template +void +dionysus::Rips:: +generate(Dimension k, DistanceType max, const Functor& f, Iterator bg, Iterator end) const +{ + auto neighbor = [this, max](Vertex u, Vertex v) { return this->distances()(u,v) <= max; }; + + // current = empty + // candidates = everything + VertexContainer current; + VertexContainer candidates(bg, end); + bron_kerbosch(current, candidates, std::prev(candidates.begin()), k, neighbor, f); +} + +template +template +void +dionysus::Rips:: +vertex_cofaces(IndexType v, Dimension k, DistanceType max, const Functor& f, Iterator bg, Iterator end) const +{ + auto neighbor = [this, max](Vertex u, Vertex v) { return this->distances()(u,v) <= max; }; + + // current = [v] + // candidates = everything - [v] + VertexContainer current; current.push_back(v); + VertexContainer candidates; + for (Iterator cur = bg; cur != end; ++cur) + if (*cur != v && neighbor(v, *cur)) + candidates.push_back(*cur); + + bron_kerbosch(current, candidates, std::prev(candidates.begin()), k, neighbor, f); +} + +template +template +void +dionysus::Rips:: +edge_cofaces(IndexType u, IndexType v, Dimension k, DistanceType max, const Functor& f, Iterator bg, Iterator end) const +{ + auto neighbor = [this, max](Vertex u, Vertex v) { return this->distances()(u,v) <= max; }; + + // current = [u,v] + // candidates = everything - [u,v] + VertexContainer current; current.push_back(u); current.push_back(v); + + VertexContainer candidates; + for (Iterator cur = bg; cur != end; ++cur) + if (*cur != u && *cur != v && neighbor(v,*cur) && neighbor(u,*cur)) + candidates.push_back(*cur); + + bron_kerbosch(current, candidates, std::prev(candidates.begin()), k, neighbor, f); +} + +template +template +void +dionysus::Rips:: +cofaces(const Simplex& s, Dimension k, DistanceType max, const Functor& f, Iterator bg, Iterator end) const +{ + namespace ba = boost::adaptors; + + auto neighbor = [this, max](Vertex u, Vertex v) { return this->distances()(u,v) <= max; }; + + // current = s + VertexContainer current(s.begin(), s.end()); + + // candidates = everything - s that is a neighbor of every vertex in the simplex + VertexContainer candidates; + boost::set_difference(std::make_pair(bg, end) | + ba::filtered([this,&s,&neighbor](Vertex cur) + { for (auto& v : s) + if (!neighbor(v, cur)) + return false; + }), + s, + std::back_inserter(candidates)); + + bron_kerbosch(current, candidates, std::prev(candidates.begin()), k, neighbor, f, false); +} + + +template +template +void +dionysus::Rips:: +bron_kerbosch(VertexContainer& current, + const VertexContainer& candidates, + typename VertexContainer::const_iterator excluded, + Dimension max_dim, + const NeighborTest& neighbor, + const Functor& functor, + bool check_initial) +{ + if (check_initial && !current.empty()) + functor(Simplex(current)); + + if (current.size() == static_cast(max_dim) + 1) + return; + + for (auto cur = std::next(excluded); cur != candidates.end(); ++cur) + { + current.push_back(*cur); + + VertexContainer new_candidates; + for (auto ccur = candidates.begin(); ccur != cur; ++ccur) + if (neighbor(*ccur, *cur)) + new_candidates.push_back(*ccur); + size_t ex = new_candidates.size(); + for (auto ccur = std::next(cur); ccur != candidates.end(); ++ccur) + if (neighbor(*ccur, *cur)) + new_candidates.push_back(*ccur); + excluded = new_candidates.begin() + (ex - 1); + + bron_kerbosch(current, new_candidates, excluded, max_dim, neighbor, functor); + current.pop_back(); + } +} + +template +typename dionysus::Rips::DistanceType +dionysus::Rips:: +distance(const Simplex& s1, const Simplex& s2) const +{ + DistanceType mx = 0; + for (auto a : s1) + for (auto b : s2) + mx = std::max(mx, distances_(a,b)); + return mx; +} + +template +typename dionysus::Rips::DistanceType +dionysus::Rips:: +max_distance() const +{ + DistanceType mx = 0; + for (IndexType a = distances_.begin(); a != distances_.end(); ++a) + for (IndexType b = std::next(a); b != distances_.end(); ++b) + mx = std::max(mx, distances_(a,b)); + return mx; +} + +template +typename dionysus::Rips::DistanceType +dionysus::Rips::Evaluator:: +operator()(const Simplex& s) const +{ + DistanceType mx = 0; + for (auto a = s.begin(); a != s.end(); ++a) + for (auto b = std::next(a); b != s.end(); ++b) + mx = std::max(mx, distances_(*a,*b)); + return mx; +} diff --git a/src/dionysus/dionysus/row-reduction.h b/src/dionysus/dionysus/row-reduction.h new file mode 100755 index 0000000..e2481ce --- /dev/null +++ b/src/dionysus/dionysus/row-reduction.h @@ -0,0 +1,54 @@ +#ifndef DIONYSUS_ROW_REDUCTION_H +#define DIONYSUS_ROW_REDUCTION_H + +#include "reduced-matrix.h" + +namespace dionysus +{ + +// Mid-level interface +template, template class... Visitors> +class RowReduction +{ + public: + typedef Field_ Field; + typedef Index_ Index; + typedef Comparison_ Comparison; + + typedef ReducedMatrix Persistence; + + public: + RowReduction(const Field& field): + persistence_(field) {} + + RowReduction(const Field& field, + const Comparison& cmp, + const Visitors&... visitors): + persistence_(field, cmp, visitors...) {} + + template + void operator()(const Filtration& f, const Relative& relative, const ReportPair& report_pair, const Progress& progress); + + template + void operator()(const Filtration& f, const ReportPair& report_pair); + + template + void operator()(const Filtration& f) { return (*this)(f, &no_report_pair); } + + static void no_report_pair(int, Index, Index) {} + static void no_progress() {} + + const Persistence& + persistence() const { return persistence_; } + Persistence& persistence() { return persistence_; } + + private: + Persistence persistence_; +}; + +} + +#include "row-reduction.hpp" + +#endif + diff --git a/src/dionysus/dionysus/row-reduction.hpp b/src/dionysus/dionysus/row-reduction.hpp new file mode 100755 index 0000000..edb1652 --- /dev/null +++ b/src/dionysus/dionysus/row-reduction.hpp @@ -0,0 +1,103 @@ +#include +namespace ba = boost::adaptors; + +template class... V> +template +void +dionysus::RowReduction:: +operator()(const Filtration& filtration, const ReportPair& report_pair) +{ + using Cell = typename Filtration::Cell; + (*this)(filtration, [](const Cell&) { return false; }, report_pair, &no_progress); +} + +template class... V> +template +void +dionysus::RowReduction:: +operator()(const Filtration& filtration, const Relative& relative, const ReportPair& report_pair, const Progress& progress) +{ + persistence_.resize(filtration.size()); + + typedef typename Persistence::Index Index; + typedef typename Persistence::FieldElement Element; + typedef typename Persistence::Chain Chain; + typedef typename Filtration::Cell Cell; + typedef ChainEntry CellChainEntry; + typedef ChainEntry ChainEntry; + + std::vector rows(persistence_.size()); + + auto& field = persistence_.field(); + + // fill the matrix + Index i = 0; + for(auto& c : filtration) + { + progress(); + + if (relative(c)) + { + persistence_.set_skip(i); + ++i; + continue; + } + + persistence_.set(i, c.boundary(field) | + ba::filtered([relative](const CellChainEntry& e) { return !relative(e.index()); }) | + ba::transformed([this,&filtration](const CellChainEntry& e) + { return ChainEntry(e.element(), filtration.index(e.index())); })); + if (!persistence_[i].empty()) + { + auto& x = persistence_[i].back(); + rows[x.index()].emplace_back(x.element(),i); + } + ++i; + } + + auto entry_cmp = [this](const ChainEntry& e1, const ChainEntry& e2) { return this->persistence_.cmp()(e1.index(), e2.index()); }; + + // reduce the matrix from the bottom up + for (auto it = rows.rbegin(); it != rows.rend(); ++it) + { + auto& row = *it; + Index r = rows.rend() - it - 1; + + if (row.empty()) + continue; + + // add the first column to every other column + Index c = row.front().index(); + Element e = row.front().element(); + Chain& first = persistence_.column(c); + for (size_t i = 1; i < row.size(); ++i) + { + Index cur_idx = row[i].index(); + Element cur_elem = row[i].element(); + Chain& cur = persistence_.column(cur_idx); + if (cur.empty()) // zeroed out by the clearing optimization + continue; + + Element m = field.neg(field.div(cur_elem, e)); + // cur += m*first + ::dionysus::Chain::addto(cur, m, first, field, entry_cmp); + + // update row + if (!cur.empty()) + { + ChainEntry ce = cur.back(); + auto& new_row = rows[ce.index()]; + new_row.emplace_back(ce.element(), cur_idx); + if (entry_cmp(new_row.back(), new_row.front())) + std::swap(new_row.back(), new_row.front()); + } + } + + persistence_.set_pair(r,c); + report_pair(filtration[r].dimension(), r, c); + + // zero out the corresponding column (the clearing optimization) + persistence_.column(r).clear(); + } +} + diff --git a/src/dionysus/dionysus/simplex.h b/src/dionysus/dionysus/simplex.h new file mode 100755 index 0000000..c974979 --- /dev/null +++ b/src/dionysus/dionysus/simplex.h @@ -0,0 +1,272 @@ +#ifndef DIONYSUS_SIMPLEX_H +#define DIONYSUS_SIMPLEX_H + +#include + +//#include +#include +#include +#include + +#include "chain.h" + +namespace dionysus +{ + +struct Empty {}; + +template +class Simplex +{ + public: + typedef Vertex_ Vertex; + typedef T Data; + typedef std::unique_ptr Vertices; + + template + struct BoundaryChainIterator; + struct BoundaryIterator; + + template + using BoundaryChainRange = boost::iterator_range>; + using BoundaryRange = boost::iterator_range; + + public: + Simplex(const Data& d = Data()): + dim_(-1), data_(d) {} + + Simplex(const std::initializer_list& vertices, + Data&& d = Data()): + Simplex(vertices.size() - 1, vertices.begin(), vertices.end(), std::move(d)) + {} + + Simplex(const std::initializer_list& vertices, + const Data& d): + Simplex(vertices.size() - 1, vertices.begin(), vertices.end(), d) {} + + Simplex(short unsigned dim, Vertices&& vertices, Data&& data = Data()): + dim_(dim), vertices_(std::move(vertices)), data_(std::move(data)) { std::sort(begin(), end()); } + + template + Simplex(const VertexRange& vertices, + Data&& d = Data()): + Simplex(vertices.size() - 1, vertices.begin(), vertices.end(), std::move(d)) + {} + + template + Simplex(const VertexRange& vertices, + const Data& d): + Simplex(vertices.size() - 1, vertices.begin(), vertices.end(), d) {} + + Simplex(const Simplex& other): + Simplex(other.dim_, other.begin(), other.end(), other.data_) {} + Simplex& operator=(const Simplex& other) { dim_ = other.dim_; vertices_ = Vertices(new Vertex[dim_+1]); std::copy(other.begin(), other.end(), begin()); data_ = other.data_; return *this; } + + Simplex(Simplex&& other) noexcept: + dim_(other.dim_), + vertices_(std::move(other.vertices_)), + data_(std::move(other.data_)) {} + Simplex& operator=(Simplex&& other) = default; + + template + Simplex(short unsigned dim, + Iterator b, Iterator e, + Data&& d = Data()): + dim_(dim), + vertices_(new Vertex[dim_+1]), + data_(std::move(d)) { std::copy(b, e, begin()); std::sort(begin(), end()); } + + template + Simplex(short unsigned dim, + Iterator b, Iterator e, + const Data& d): + dim_(dim), + vertices_(new Vertex[dim_+1]), + data_(d) { std::copy(b, e, begin()); std::sort(begin(), end()); } + + short unsigned dimension() const { return dim_; } + + BoundaryRange boundary() const { return BoundaryRange(boundary_begin(), boundary_end()); } + BoundaryIterator boundary_begin() const; + BoundaryIterator boundary_end() const; + + template + BoundaryChainRange + boundary(const Field& field) const { return BoundaryChainRange(boundary_begin(field), boundary_end(field)); } + + template + BoundaryChainIterator + boundary_begin(const Field& field) const; + template + BoundaryChainIterator + boundary_end(const Field& field) const; + + const Vertex* begin() const { return vertices_.get(); } + const Vertex* end() const { return begin() + dim_ + 1; } + size_t size() const { return dim_ + 1; } + + std::pair + range() const { return std::make_pair(begin(), end()); } + + Simplex join(const Vertex& v) const { Vertices vertices(new Vertex[dim_+2]); std::copy(begin(), end(), vertices.get()); vertices[dim_+1] = v; return Simplex(dim_ + 1, std::move(vertices), Data(data_)); } + + bool operator==(const Simplex& other) const { return dim_ == other.dim_ && std::equal(begin(), end(), other.begin()); } + bool operator!=(const Simplex& other) const { return !operator==(other); } + bool operator<(const Simplex& other) const { return dim_ < other.dim_ || (dim_ == other.dim_ && std::lexicographical_compare(begin(), end(), other.begin(), other.end())); } + bool operator>(const Simplex& other) const { return other < (*this); } + + Vertex operator[](short unsigned i) const { return vertices_[i]; } + const Data& data() const { return data_; } + Data& data() { return data_; } + + friend std::ostream& operator<<(std::ostream& out, const Simplex& s) { + out << '<' << *(s.begin()); + for (auto it = s.begin() + 1; it != s.end(); ++it) { + out << ',' << *it; + } + out << '>'; + return out; + } + + private: + Vertex* begin() { return vertices_.get(); } + Vertex* end() { return begin() + dim_ + 1; } + + private: + short unsigned dim_; + //boost::compressed_pair vertices_data_; + Vertices vertices_; + Data data_; // TODO: optimize +}; + +template +size_t hash_value(const Simplex& s) { return boost::hash_range(s.begin(), s.end()); } + + +template +struct Simplex::BoundaryIterator: + public boost::iterator_adaptor, // Value + boost::use_default, + Simplex> // Reference +{ + public: + typedef const V* Iterator; + typedef Simplex Value; + + typedef boost::iterator_adaptor Parent; + + BoundaryIterator() {} + explicit BoundaryIterator(short unsigned dim, Iterator iter, Iterator bg, Iterator end): + Parent(iter), dim_(dim), bg_(bg), end_(end) {} + + Iterator begin() const { return bg_; } + + private: + friend class boost::iterator_core_access; + Value dereference() const + { + typedef std::not_equal_to NotEqualVertex; + + return Simplex(dim_ - 1, + boost::make_filter_iterator(std::bind2nd(NotEqualVertex(), *(this->base())), bg_, end_), + boost::make_filter_iterator(std::bind2nd(NotEqualVertex(), *(this->base())), end_, end_)); + } + + short unsigned dim_; + Iterator bg_; + Iterator end_; +}; + +template +template +struct Simplex::BoundaryChainIterator: + public boost::iterator_adaptor, // Derived + BoundaryIterator, + ChainEntry>, // Value + boost::use_default, + ChainEntry>> // Reference +{ + public: + typedef F Field; + typedef BoundaryIterator Iterator; + typedef ChainEntry> Value; + + typedef boost::iterator_adaptor Parent; + + BoundaryChainIterator() {} + explicit BoundaryChainIterator(const Field& field, Iterator iter): + Parent(iter), field_(&field) {} + + private: + friend class boost::iterator_core_access; + Value dereference() const + { + return Value(((this->base().base() - this->base().begin()) % 2 == 0)? field_->id() : field_->neg(field_->id()), + *(this->base())); + } + + const Field* field_ = nullptr; +}; + + +/* Simplex */ +template +typename Simplex::BoundaryIterator +Simplex:: +boundary_begin() const +{ + if (dimension() == 0) return boundary_end(); + return BoundaryIterator(dimension(), begin(), begin(), end()); +} + +template +typename Simplex::BoundaryIterator +Simplex:: +boundary_end() const +{ + return BoundaryIterator(dimension(), end(), begin(), end()); +} + +template +template +typename Simplex::template BoundaryChainIterator +Simplex:: +boundary_begin(const F& field) const +{ + if (dimension() == 0) return boundary_end(field); + return BoundaryChainIterator(field, boundary_begin()); +} + +template +template +typename Simplex::template BoundaryChainIterator +Simplex:: +boundary_end(const F& field) const +{ + return BoundaryChainIterator(field, boundary_end()); +} + +} // dionysus + +namespace std +{ + +template +struct hash> +{ + size_t operator()(const dionysus::Simplex& s) const { return hash_value(s); } +}; + +} // std + +#endif diff --git a/src/dionysus/dionysus/sparse-row-matrix.h b/src/dionysus/dionysus/sparse-row-matrix.h new file mode 100755 index 0000000..fb1e929 --- /dev/null +++ b/src/dionysus/dionysus/sparse-row-matrix.h @@ -0,0 +1,184 @@ +#ifndef DIONYSUS_SPARSE_ROW_MATRIX_H +#define DIONYSUS_SPARSE_ROW_MATRIX_H + +#include +#include +#include +#include // for debugging output + +#include + +#include "chain.h" +#include "reduction.h" + +namespace dionysus +{ + +namespace bi = boost::intrusive; + +namespace detail +{ + typedef bi::list_base_hook> auto_unlink_hook; + + template + struct SparseRowMatrixEntry: + public ChainEntry, auto_unlink_hook> + { + typedef I Index; + typedef typename F::Element FieldElement; + typedef std::tuple IndexPair; // (id, pair) + typedef ChainEntry Parent; + typedef SparseRowMatrixEntry Entry; + + SparseRowMatrixEntry(FieldElement e, const IndexPair& ip): + Parent(e,ip) {} + + SparseRowMatrixEntry(FieldElement e, const Index& r, const Index& c): + Parent(e,IndexPair(r,c)) {} + + SparseRowMatrixEntry(const Entry& other) = default; + SparseRowMatrixEntry(Entry&& other) = default; + Entry& operator=(Entry&& other) = default; + + void unlink() { auto_unlink_hook::unlink(); } + bool is_linked() const { return auto_unlink_hook::is_linked(); } + }; +} + +template, + template class Column_ = std::vector> +class SparseRowMatrix +{ + public: + typedef Field_ Field; + typedef Index_ Index; + typedef Comparison_ Comparison; + + typedef typename Field::Element FieldElement; + + typedef detail::SparseRowMatrixEntry Entry; + typedef Column_ Column; + typedef typename Entry::IndexPair IndexPair; + typedef bi::list> Row; + + typedef std::vector> IndexChain; + + typedef std::unordered_map Columns; + typedef std::unordered_map Rows; + typedef std::unordered_map LowMap; + + public: + SparseRowMatrix(const Field& field, + const Comparison& cmp = Comparison()): + field_(field), cmp_(cmp) {} + + SparseRowMatrix(SparseRowMatrix&& other) = default; + + + template + Column reduce(const ChainRange& chain, IndexChain& trail); + + Index set(Index i, Column&& chain); // returns previous column with this low + void fix(Index c, Column& column); + void fix(Index c) { fix(c, col(c)); } + + const Row& prepend_row(Index r, FieldElement m, const Row& chain); // could be horribly inefficient if Column is chosen poorly + + void drop_row(Index r) { rows_.erase(r); if (is_low(r)) lows_.erase(r); } + void drop_col(Index c) + { + auto cit = columns_.find(c); + Column& column = cit->second; + if (!column.empty()) + { + Index rlow = std::get<0>(column.back().index()); + auto it = lows_.find(rlow); + if (it != lows_.end() && it->second == c) + lows_.erase(it); + } + columns_.erase(cit); + } + void drop_low(Index r) { lows_.erase(r); } + + // accessors + Row& row(Index r) { return rows_[r]; } + Column& col(Index c) { assert(col_exists(c)); return columns_.find(c)->second; } + const Column& col(Index c) const { assert(col_exists(c)); return columns_.find(c)->second; } + Index low(Index r) const { return lows_.find(r)->second; } + bool is_low(Index r) const { return lows_.find(r) != lows_.end(); } + void update_low(Index c) { lows_[std::get<0>(col(c).back().index())] = c; } + + const Field& field() const { return field_; } + void reserve(size_t) {} // here for compatibility only + const Comparison& cmp() const { return cmp_; } + + // debug + bool col_exists(Index c) const { return columns_.find(c) != columns_.end(); } + const Columns& columns() const { return columns_; } + void check_columns() const + { + for (auto& x : columns_) + { + Index c = x.first; + if (x.second.empty()) + std::cout << "Warning: empty column " << c << std::endl; + Index rl = std::get<0>(x.second.back().index()); + if (!is_low(rl) || low(rl) != c) + { + std::cout << "Columns don't check out: lows don't match" << std::endl; + std::cout << " " << c << ' ' << rl << ' ' << ' ' << low(rl) << std::endl; + std::cout << "---\n"; + for (auto& x : col(c)) + std::cout << " " << x.element() << ' ' << std::get<0>(x.index()) << ' ' << std::get<1>(x.index()) << '\n'; + std::cout << "---\n"; + for (auto& x : col(low(rl))) + std::cout << " " << x.element() << ' ' << std::get<0>(x.index()) << ' ' << std::get<1>(x.index()) << '\n'; + assert(0); + } + + for (auto& x : lows_) + { + if (!col_exists(x.second)) + { + std::cout << "Still keeping low of a removed column" << std::endl; + assert(0); + } + else if (std::get<0>(col(x.second).back().index()) != x.first) + { + std::cout << "Low mismatch: " << x.second << ' ' << std::get<0>(col(x.second).back().index()) << ' ' << x.first << '\n'; + assert(0); + } + } + } + } + + private: + Field field_; + Comparison cmp_; + + Columns columns_; + Rows rows_; + LowMap lows_; // column that has this low +}; + + +namespace detail +{ + +template +struct Unpaired> +{ + static + constexpr std::tuple + value() + { return std::make_tuple(std::numeric_limits::max(), + std::numeric_limits::max()); } +}; + +} + +} + +#include "sparse-row-matrix.hpp" + +#endif diff --git a/src/dionysus/dionysus/sparse-row-matrix.hpp b/src/dionysus/dionysus/sparse-row-matrix.hpp new file mode 100755 index 0000000..10f4808 --- /dev/null +++ b/src/dionysus/dionysus/sparse-row-matrix.hpp @@ -0,0 +1,103 @@ +template class Col> +template +typename dionysus::SparseRowMatrix::Column +dionysus::SparseRowMatrix:: +reduce(const ChainRange& chain_, IndexChain& trail) +{ + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp_(std::get<0>(e1.index()), std::get<0>(e2.index())); }; + +#define __DIONYSUS_USE_VECTOR_CHAINS 1 + +#if !(__DIONYSUS_USE_VECTOR_CHAINS) + std::set chain(row_cmp); + for (auto x : chain_) + chain.insert(Entry(x.element(), IndexPair(x.index(), 0))); +#else + Column chain; + for (auto x : chain_) + chain.emplace_back(x.element(), IndexPair(x.index(), 0)); + std::sort(chain.begin(), chain.end(), row_cmp); +#endif + + typedef Reduction ReductionIP; + + auto chains = [this](const IndexPair& rc) -> const Column& { return this->col(std::get<1>(rc)); }; + auto lows = [this](const IndexPair& rc) -> IndexPair + { + Index r = std::get<0>(rc); + auto it = this->lows_.find(r); + if (it == this->lows_.end()) + return ReductionIP::unpaired; + else + { + Index rr = std::get<0>(col(it->second).back().index()); + if (rr != r) + std::cout << "Mismatch: " << rr << ' ' << r << std::endl; + return IndexPair(r, it->second); + } + }; + + auto addto = [&trail](FieldElement m, const IndexPair& rc) { trail.emplace_back(m, std::get<1>(rc)); }; + + ReductionIP::reduce(chain, + chains, lows, + field_, addto, row_cmp); + +#if !(__DIONYSUS_USE_VECTOR_CHAINS) + return Column(std::begin(chain), std::end(chain)); +#else + return chain; +#endif +} + +template class Col> +typename dionysus::SparseRowMatrix::Index +dionysus::SparseRowMatrix:: +set(Index col, Column&& chain) +{ + Column& column = columns_.emplace(col, std::move(chain)).first->second; + + fix(col, column); + + Index r = std::get<0>(column.back().index()); + Index res; + if (is_low(r)) + res = low(r); + else + res = col; + lows_[r] = col; + + return res; +} + +template class Col> +void +dionysus::SparseRowMatrix:: +fix(Index col, Column& column) +{ + for (auto& x : column) + { + std::get<1>(x.index()) = col; + Index r = std::get<0>(x.index()); + row(r).push_back(x); + } +} + +template class Col> +const typename dionysus::SparseRowMatrix::Row& +dionysus::SparseRowMatrix:: +prepend_row(Index r, FieldElement m, const Row& chain) +{ + Row& new_row = row(r); + + for (auto& x : chain) + { + Index c = std::get<1>(x.index()); + Column& column = col(c); + auto it = column.emplace(column.begin(), field().mul(x.element(), m), r, c); + new_row.push_back(*it); + } + + return new_row; +} diff --git a/src/dionysus/dionysus/standard-reduction.h b/src/dionysus/dionysus/standard-reduction.h new file mode 100755 index 0000000..0477d46 --- /dev/null +++ b/src/dionysus/dionysus/standard-reduction.h @@ -0,0 +1,44 @@ +#ifndef DIONYSUS_STANDARD_REDUCTION_H +#define DIONYSUS_STANDARD_REDUCTION_H + +namespace dionysus +{ + +// Mid-level interface +template +class StandardReduction +{ + public: + typedef Persistence_ Persistence; + typedef typename Persistence::Field Field; + typedef typename Persistence::Index Index; + + public: + StandardReduction(Persistence& persistence): + persistence_(persistence) {} + + template + void operator()(const Filtration& f, const Relative& relative, const ReportPair& report_pair, const Progress& progress); + + template + void operator()(const Filtration& f, const ReportPair& report_pair); + + template + void operator()(const Filtration& f) { return (*this)(f, &no_report_pair); } + + static void no_report_pair(int, Index, Index) {} + static void no_progress() {} + + const Persistence& + persistence() const { return persistence_; } + Persistence& persistence() { return persistence_; } + + private: + Persistence& persistence_; +}; + +} + +#include "standard-reduction.hpp" + +#endif diff --git a/src/dionysus/dionysus/standard-reduction.hpp b/src/dionysus/dionysus/standard-reduction.hpp new file mode 100755 index 0000000..9aa3396 --- /dev/null +++ b/src/dionysus/dionysus/standard-reduction.hpp @@ -0,0 +1,47 @@ +#include +namespace ba = boost::adaptors; + +template +template +void +dionysus::StandardReduction

:: +operator()(const Filtration& filtration, const ReportPair& report_pair) +{ + using Cell = typename Filtration::Cell; + (*this)(filtration, [](const Cell&) { return false; }, report_pair, no_progress); +} + +template +template +void +dionysus::StandardReduction

:: +operator()(const Filtration& filtration, const Relative& relative, const ReportPair& report_pair, const Progress& progress) +{ + persistence_.reserve(filtration.size()); + + typedef typename Filtration::Cell Cell; + typedef ChainEntry CellChainEntry; + typedef ChainEntry ChainEntry; + + unsigned i = 0; + for(auto& c : filtration) + { + progress(); + + if (relative(c)) + { + ++i; + persistence_.add_skip(); + continue; + } + + //std::cout << "Adding: " << c << " : " << boost::distance(c.boundary(persistence_.field())) << std::endl; + Index pair = persistence_.add(c.boundary(persistence_.field()) | + ba::filtered([relative](const CellChainEntry& e) { return !relative(e.index()); }) | + ba::transformed([this,&filtration](const CellChainEntry& e) + { return ChainEntry(e.element(), filtration.index(e.index())); })); + if (pair != persistence_.unpaired()) + report_pair(c.dimension(), pair, i); + ++i; + } +} diff --git a/src/dionysus/dionysus/trails-chains.h b/src/dionysus/dionysus/trails-chains.h new file mode 100755 index 0000000..f18ff89 --- /dev/null +++ b/src/dionysus/dionysus/trails-chains.h @@ -0,0 +1,17 @@ +#ifndef DIONYSUS_TRAILS_CHAINS_H +#define DIONYSUS_TRAILS_CHAINS_H + +#include "ordinary-persistence.h" + +template +struct ChainsVisitor: public EmptyVisitor +{ + template + void chain_initialized(Chain& c) { } + + void addto(typename Field::Element m, Index cl) {} + void reduction_finished() {} +}; + + +#endif diff --git a/src/dionysus/dionysus/zigzag-persistence.h b/src/dionysus/dionysus/zigzag-persistence.h new file mode 100755 index 0000000..6731f2d --- /dev/null +++ b/src/dionysus/dionysus/zigzag-persistence.h @@ -0,0 +1,141 @@ +#ifndef DIONYSUS_ZIGZAG_PERSISTENCE_H +#define DIONYSUS_ZIGZAG_PERSISTENCE_H + +#include +#include + +#include +#include + +#include "sparse-row-matrix.h" + +namespace dionysus +{ + +namespace ba = boost::adaptors; + +template> +class ZigzagPersistence +{ + static_assert(std::is_signed::value, "Index type used in ZigzagPersistence must be a *signed* integer"); + + public: + typedef Field_ Field; + typedef Index_ Index; + typedef Comparison_ Comparison; + + typedef SparseRowMatrix RowMatrix; + typedef SparseRowMatrix DequeRowMatrix; + typedef typename RowMatrix::IndexPair IndexPair; + typedef typename RowMatrix::FieldElement FieldElement; + typedef typename RowMatrix::IndexChain IndexChain; + typedef typename RowMatrix::Column Column; + typedef typename RowMatrix::Row Row; + typedef typename DequeRowMatrix::Column DequeColumn; + typedef typename DequeRowMatrix::Row DequeRow; + + typedef std::unordered_map BirthIndexMap; + + + ZigzagPersistence(const Field& field, + const Comparison& cmp = Comparison()): + Z(field, cmp), C(field, cmp), B(field, cmp), + operations(0), + cell_indices(0), + z_indicies_last(0), + z_indicies_first(-1), + b_indices(0) {} + + template + Index add(const ChainRange& chain) // returns the id of the dying cycle (or unpaired) + { + Index res = add_impl(chain); +#ifdef DIONYSUS_ZIGZAG_DEBUG + check_sorted(); + check_b_cols(); + Z.check_columns(); +#endif + return res; + } + Index remove(Index cell) + { + Index res = remove_impl(cell); +#ifdef DIONYSUS_ZIGZAG_DEBUG + check_sorted(); + check_b_cols(); + Z.check_columns(); +#endif + return res; + } + + struct IsAlive + { + IsAlive(const ZigzagPersistence& zz_): zz(&zz_) {} + bool operator()(const std::pair& x) const { return zz->is_alive(x.first); } + const ZigzagPersistence* zz; + }; + + bool is_alive(Index x) const { return !B.is_low(x); } + + auto alive_ops() const -> decltype(BirthIndexMap() | ba::filtered(IsAlive(*this)) | ba::map_values) + { return birth_index | ba::filtered(IsAlive(*this)) | ba::map_values; } + + auto alive_cycles() const -> decltype(BirthIndexMap() | ba::filtered(IsAlive(*this)) | ba::map_keys) + { return birth_index | ba::filtered(IsAlive(*this)) | ba::map_keys; } + + size_t alive_size() const { return Z.columns().size() - B.columns().size(); } + + void reserve(size_t) {} // here for compatibility only + const Field& field() const { return Z.field(); } + const Comparison& cmp() const { return Z.cmp(); } + + template + static Index row(const Entry& e) { return std::get<0>(e.index()); } + template + static Index col(const Entry& e) { return std::get<1>(e.index()); } + + static + const Index unpaired() { return Reduction::unpaired; } + + const Column& cycle(Index i) const { return Z.col(i); } + + // debug + void check_b_cols() const; + + template + void check_boundaries(const SimplexToIndex& s2i, const IndexToSimplex& i2s) const; + template + void check_cycles(const SimplexToIndex& s2i, const IndexToSimplex& i2s) const; + + Column zb_dot(Index c) const; + + template + Column dc_dot(Index c, const SimplexToIndex& s2i, const IndexToSimplex& i2s) const; + + template + Column boundary(Index i, const SimplexToIndex& s2i, const IndexToSimplex& i2s) const; + + void check_sorted() const; + + private: + template + Index add_impl(const ChainRange& chain); + Index remove_impl(Index cell); + + private: + RowMatrix Z, C; + DequeRowMatrix B; + + BirthIndexMap birth_index; + Index operations; + Index cell_indices; + Index z_indicies_last, z_indicies_first; + Index b_indices; +}; + +} + +#include "zigzag-persistence.hpp" + +#endif diff --git a/src/dionysus/dionysus/zigzag-persistence.hpp b/src/dionysus/dionysus/zigzag-persistence.hpp new file mode 100755 index 0000000..2d52789 --- /dev/null +++ b/src/dionysus/dionysus/zigzag-persistence.hpp @@ -0,0 +1,534 @@ +template +template +typename dionysus::ZigzagPersistence::Index +dionysus::ZigzagPersistence:: +add_impl(const ChainRange& chain_) +{ + //std::cout << "add(" << cell_indices << ")" << std::endl; + Index op = operations++; + + IndexChain cycles; // chain_ -> Z*cycles + Column z_remainder = Z.reduce(chain_, cycles); + assert(z_remainder.empty()); + + IndexChain boundaries; + DequeColumn b_remainder = B.reduce(cycles, boundaries); + + // add up columns of C indexed by boundaries + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + Column chain; + for (auto& x : boundaries) + Chain::addto(chain, x.element(), C.col(x.index()), field(), row_cmp); + chain.push_back(Entry(field().neg(field().id()), IndexPair(cell_indices++,0))); + + if (b_remainder.empty()) // birth + { + //std::cout << " birth" << std::endl; + Index z_col = z_indicies_last++; + Z.set(z_col, std::move(chain)); + birth_index[z_col] = op; + return unpaired(); + } + else // death + { + //std::cout << " death" << std::endl; + Index b_col = b_indices++; + Index pair = row(b_remainder.back()); + B.set(b_col, std::move(b_remainder)); + C.set(b_col, std::move(chain)); + return birth_index[pair]; + } +} + +template +typename dionysus::ZigzagPersistence::Index +dionysus::ZigzagPersistence:: +remove_impl(Index cell) +{ + //std::cout << "remove(" << cell << ")" << std::endl; + + Index op = operations++; + + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + typedef typename DequeColumn::value_type DequeEntry; + auto b_row_cmp = [this](const DequeEntry& e1, const DequeEntry& e2) + { return this->cmp()(row(e1), row(e2)); }; + + IndexChain z_row; + for (auto& x : Z.row(cell)) + z_row.emplace_back(x.element(), col(x)); + + if (z_row.empty()) // birth + { + //std::cout << " birth" << std::endl; + Row& c_row = C.row(cell); + // c_row.front() may not be the first column in order, but that doesn't really matter, does it? (TODO) + auto& c_front = c_row.front(); + + Index j = col(c_front); + Index l = row(B.col(j).back()); + + //std::cout << j << ' ' << l << std::endl; + + // cycle = ZB[j] = DC[j] + Column cycle; + for (auto& x : B.col(j)) + Chain::addto(cycle, x.element(), Z.col(row(x)), field(), row_cmp); + + //std::cout << "Cycle:" << std::endl; + //for (auto& x : cycle) + // std::cout << x.element() << ' ' << row(x) << std::endl; + + // 1: prepend the cycle + Index znew = z_indicies_first--; + Index oth = Z.set(znew, std::move(cycle)); // oth records our collision (used in step 6) + birth_index[znew] = op; + + //std::cout << "znew oth: " << znew << ' ' << oth << std::endl; + //std::cout << "oth column:" << std::endl; + //for (auto& x : Z.col(oth)) + // std::cout << x.element() << ' ' << row(x) << std::endl; + + // 2: prepend the row to B + FieldElement m = field().neg(field().inv(c_front.element())); // m = -1/c + const DequeRow& b_row = B.prepend_row(znew, m, c_row); + //std::cout << "Prepended row with multiplier: " << m << " (" << b_row.size() << ")" << std::endl; + + // 3: subtract C[j] from every C[k] + const Column& Cj = C.col(j); + + // use the copy of c_row in B, since c_row will be modified in the following loop + for (auto it = std::next(b_row.begin()); it != b_row.end(); ++it) + { + Index c = col(*it); + assert(c != j); + //std::cout << "adding to " << c << " in C" << std::endl; + Chain::addto(C.col(c), it->element(), Cj, field(), row_cmp); // using it->element() since b_row = m*c_row + C.fix(c); // old elements got removed via auto_unlink_hook + // we don't need lows in C, so not updating them + } + //std::cout << "Done with step 3" << std::endl; + + // 4: subtract B[j] from every B[k] that has l + // (we don't need to update C because ZB[j] = 0 after step 2) + DequeColumn& Bj = B.col(j); + FieldElement bm = field().neg(field().inv(Bj.back().element())); // bm = -1/B[l,j] + IndexChain Bl_row; // make a copy of Bl_row, since it will be changing + for (auto& x : B.row(l)) + { + if (col(x) == j) + continue; + Bl_row.emplace_back(x.element(), col(x)); + } + for (auto& x : Bl_row) + { + Index c = x.index(); + assert(c != j); + Chain::addto(B.col(c), field().mul(bm, x.element()), Bj, field(), b_row_cmp); + B.fix(c); // old elements got removed via auto_unlink_hook + // l cannot be the low in c, so no need to update lows + } + //std::cout << "Done with step 4" << std::endl; + + // 5: drop row l and column j from B; drop column l from Z; drop column j from C + B.drop_col(j); + assert(B.row(l).empty()); + B.drop_row(l); + Index Zl_low = row(Z.col(l).back()); + Z.drop_col(l); + birth_index.erase(l); + C.drop_col(j); + assert(Z.row(cell).empty()); + assert(C.row(cell).empty()); + C.drop_row(cell); + Z.drop_row(cell); + //std::cout << "Done with step 5" << std::endl; + if (oth == l) // we just dropped our collision in Z + oth = znew; + else + Z.drop_low(Zl_low); + + // 6: reduce Z + std::unordered_map b_changes; // the columns to add in B to apply row changes + Index cur = znew; + while (oth != cur) + { + Column& cur_col = Z.col(cur); + Column& oth_col = Z.col(oth); + assert(row(cur_col.back()) == row(oth_col.back())); + //std::cout << "--- " << cur << " (" << cur_col.size() << ") " << oth << " (" << oth_col.size() << ")" << std::endl; + FieldElement m1 = cur_col.back().element(); + FieldElement m2 = oth_col.back().element(); + FieldElement m2_div_m1 = field().div(m2, m1); + Chain::addto(oth_col, field().neg(m2_div_m1), cur_col, field(), row_cmp); + Z.fix(oth, oth_col); + + // record the changes we need to make in B; + // because there is only one collision in the matrix during the reduction, + // once we use a row as the source, we never revisit it. This means once the row is updated in B, + // we never touch it again, so below record is fine. + for (auto& x : this->B.row(oth)) + b_changes[col(x)].emplace_back(field().mul(x.element(), m2_div_m1), cur, col(x)); + + cur = oth; + Index low = row(oth_col.back()); + if (Z.is_low(low)) + oth = Z.low(low); + //std::cout << "--- -- new low: " << low << ' ' << cur << ' ' << oth << std::endl; + + if (cmp()(oth, cur)) + std::swap(oth, cur); + else + Z.update_low(cur); + } + + // apply changes in B (the complexity here could get ugly) + for (auto& bx : b_changes) + { + std::sort(bx.second.begin(), bx.second.end(), b_row_cmp); + Chain::addto(B.col(bx.first), field().id(), bx.second, field(), b_row_cmp); + B.fix(bx.first); + // no need to update low (additions from bottom up) + } + //std::cout << "Done with step 6" << std::endl; + + return unpaired(); + } + else // death + { + //std::cout << " death" << std::endl; + + auto index_chain_cmp = [this](const typename IndexChain::value_type& e1, const typename IndexChain::value_type& e2) + { return this->cmp()(e1.index(), e2.index()); }; + + // 1: change basis to clear z_row + std::sort(z_row.begin(), z_row.end(), index_chain_cmp); // this adds a log factor, but it makes life easier + Index j = z_row.front().index(); + FieldElement e = z_row.front().element(); + + if (z_row.size() > 1) + { + // figure out the columns we use for reduction + typedef typename IndexChain::const_iterator RowIterator; + std::vector reducers; + reducers.push_back(z_row.begin()); + for (RowIterator it = std::next(z_row.begin()); it != z_row.end(); ++it) + { + Index c = it->index(); + + assert(Z.col_exists(c)); + assert(Z.col_exists(reducers.back()->index())); + if (cmp()(row(Z.col(c).back()), + row(Z.col(reducers.back()->index()).back()))) + reducers.push_back(it); + } + reducers.push_back(z_row.end()); + //std::cout << "reducers.size(): " << reducers.size() << std::endl; + //std::cout << "z_row.size(): " << z_row.size() << std::endl; + + + std::map b_changes; // the rows to add to B + auto add_in_z = [this,&b_changes,&row_cmp,&index_chain_cmp](Index to, Index from, FieldElement m, FieldElement e) + { + //std::cout << " add_in_z: " << from << ' ' << to << std::endl; + + FieldElement mult = this->field().mul(m, e); + assert(Z.col_exists(to)); + assert(Z.col_exists(from)); + Chain::addto(Z.col(to), mult, Z.col(from), this->field(), row_cmp); + assert(!Z.col(to).empty()); + this->Z.fix(to); // NB: rows will be linked in the back, so the iterators are Ok + this->Z.update_low(to); + + // subtract B.row(to) from B.row(from) + IndexChain Bto_row; + for (auto& x : this->B.row(to)) + Bto_row.emplace_back(x.element(), col(x)); + std::sort(Bto_row.begin(), Bto_row.end(), index_chain_cmp); + +#if 0 + for (auto& x : this->B.row(to)) + std::cout << x.element() << ' ' << row(x) << ' ' << col(x) << std::endl; + + std::cout << "---\n"; + + for (auto& x : this->B.row(from)) + std::cout << x.element() << ' ' << row(x) << ' ' << col(x) << std::endl; +#endif + + Chain::addto(b_changes[from], this->field().neg(mult), Bto_row, this->field(), index_chain_cmp); + + // if there is b_changes[to] add it, too + auto it = b_changes.find(to); + if (it != b_changes.end()) + Chain::addto(b_changes[from], this->field().neg(mult), it->second, this->field(), index_chain_cmp); + }; + Index last_low = row(Z.col(reducers[reducers.size() - 2]->index()).back()); + for (int i = reducers.size() - 2; i >= 0; --i) + { + auto rit = reducers[i]; + FieldElement m = field().neg(field().inv(rit->element())); + + for (auto it = std::next(rit); it != reducers[i+1]; ++it) + add_in_z(it->index(), rit->index(), m, it->element()); + + if (i + 1 != reducers.size() - 1) + { + auto it = reducers[i+1]; + add_in_z(it->index(), rit->index(), m, it->element()); + } + } + if (reducers.size() > 2) + Z.drop_low(last_low); + + // apply changes in b (the complexity here could get ugly) + // Specifically, transpose b_changes and add it in + std::unordered_map b_changes_transposed; + for (auto& b_row : b_changes) + for (auto& bx : b_row.second) + b_changes_transposed[bx.index()].emplace_back(bx.element(), b_row.first, bx.index()); + + for (auto& b_col : b_changes_transposed) + { +#if 0 + std::cout << "Adding:" << std::endl; + for (auto& x : b_col.second) + std::cout << x.element() << ' ' << row(x) << ' ' << col(x) << std::endl; +#endif + Chain::addto(B.col(b_col.first), field().id(), b_col.second, field(), b_row_cmp); + assert(!B.col(b_col.first).empty()); + B.fix(b_col.first); + // no need to update low (additions from bottom up) + } + } // z_row.size() > 1 + + // 2: subtract cycle from every chain in C + const Column& Zj = Z.col(j); + //std::cout << "Zj:" << std::endl; + //for (auto& x : Zj) + // std::cout << x.element() << " * " << row(x) << std::endl; + + IndexChain Ccols; // save the columns in C, we'll be modifying C.row(cell) + for (auto& x : C.row(cell)) + Ccols.emplace_back(x.element(), col(x)); + + for (auto& x : Ccols) + { + Index c = x.index(); + FieldElement m = field().neg(field().div(x.element(), e)); // m = -C[k][cell]/Z[j][cell] + //std::cout << "Adding to C: " << c << std::endl; + Chain::addto(C.col(c), m, Zj, field(), row_cmp); + C.fix(c); + // we don't care about lows in C, so don't update them + } + + // 3: drop + assert(Z.row(cell).size() == 1); + Z.drop_col(j); + assert(Z.row(cell).empty()); + assert(C.row(cell).empty()); + Z.drop_row(cell); + C.drop_row(cell); + assert(B.row(j).empty()); + B.drop_row(j); + + Index birth = birth_index[j]; + birth_index.erase(j); + + return birth; + } +} + + +/* debug routines */ +template +void +dionysus::ZigzagPersistence:: +check_b_cols() const +{ + // check that entries in B refer to existing Z columns + bool stop = false; + for (auto& b : B.columns()) + for (auto& x : b.second) + if (!Z.col_exists(row(x))) + { + std::cout << "B refers to a non-existent column in Z: " << row(x) << std::endl; + stop = true; + } + if (stop) + assert(0); +} + +template +template +void +dionysus::ZigzagPersistence:: +check_cycles(const SimplexToIndex& s2i, const IndexToSimplex& i2s) const +{ + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + + for (auto& z : Z.columns()) + { + Column res; + for (auto& x : z.second) + { + Column bdry = boundary(row(x), s2i, i2s); + Chain::addto(res, x.element(), bdry, field(), row_cmp); + } + assert(res.empty()); + } +} + +template +template +void +dionysus::ZigzagPersistence:: +check_boundaries(const SimplexToIndex& s2i, const IndexToSimplex& i2s) const +{ + check_cycles(s2i, i2s); + + for (auto& x : B.columns()) + if (!C.col_exists(x.first)) + { + std::cout << x.first << " in B, but not in C" << std::endl; + assert(0); + } + + for (auto& x : C.columns()) + if (!B.col_exists(x.first)) + { + std::cout << x.first << " in B, but not in C" << std::endl; + assert(0); + } + + for (auto& x : B.columns()) + { + auto zb = zb_dot(x.first); + auto dc = dc_dot(x.first, s2i, i2s); + + auto it_zb = zb.begin(), + it_dc = dc.begin(); + for (; it_zb != zb.end(); ++it_zb, ++it_dc) + { + if (it_zb->element() != it_dc->element() || row(*it_zb) != row(*it_dc)) + { + std::cout << "Boundary mismatch: " << x.first << std::endl; + std::cout << "===" << std::endl; + for (auto& x : zb) + std::cout << " " << x.element() << ' ' << row(x) << std::endl; + for (auto& y : B.col(x.first)) + { + std::cout << " " << y.element() << " * " << row(y) << std::endl; + for (auto& z : Z.col(row(y))) + std::cout << " " << z.element() << ' ' << row(z) << std::endl; + std::cout << " ---" << std::endl; + } + std::cout << "===" << std::endl; + for (auto& x : dc) + std::cout << " " << x.element() << ' ' << row(x) << std::endl; + for (auto& y : C.col(x.first)) + { + std::cout << " " << y.element() << " * " << row(y) << std::endl; + for (auto& z : boundary(row(y), s2i, i2s)) + std::cout << " " << z.element() << ' ' << row(z) << std::endl; + std::cout << " ---" << std::endl; + } + assert(0); + } + } + if (it_zb != zb.end() || it_dc != dc.end()) + { + std::cout << "zb.end() doesn't match dc.end()" << std::endl; + assert(0); + } + } +} + +template +typename dionysus::ZigzagPersistence::Column +dionysus::ZigzagPersistence:: +zb_dot(Index c) const +{ + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + Column res; + for (auto& x : B.col(c)) + Chain::addto(res, x.element(), Z.col(row(x)), field(), row_cmp); + + return res; +} + +template +template +typename dionysus::ZigzagPersistence::Column +dionysus::ZigzagPersistence:: +dc_dot(Index c, const SimplexToIndex& s2i, const IndexToSimplex& i2s) const +{ + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + Column res; + for (auto& x : C.col(c)) + { + Column bdry = boundary(row(x), s2i, i2s); + Chain::addto(res, x.element(), bdry, field(), row_cmp); + } + return res; +} + +template +template +typename dionysus::ZigzagPersistence::Column +dionysus::ZigzagPersistence:: +boundary(Index i, const SimplexToIndex& s2i, const IndexToSimplex& i2s) const +{ + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + Column bdry; + auto s = i2s(i); + for (auto y : s.boundary(field())) + bdry.emplace_back(y.element(), s2i(y.index()), 0); + std::sort(bdry.begin(), bdry.end(), row_cmp); + return bdry; +} + +template +void +dionysus::ZigzagPersistence:: +check_sorted() const +{ + typedef typename Column::value_type Entry; + auto row_cmp = [this](const Entry& e1, const Entry& e2) + { return this->cmp()(row(e1), row(e2)); }; + typedef typename DequeColumn::value_type DequeEntry; + auto b_row_cmp = [this](const DequeEntry& e1, const DequeEntry& e2) + { return this->cmp()(row(e1), row(e2)); }; + + for (auto& x : Z.columns()) + if (!std::is_sorted(x.second.begin(), x.second.end(), row_cmp)) + { + std::cout << "Z column not sorted: " << x.first << std::endl; + assert(0); + } + for (auto& x : C.columns()) + if (!std::is_sorted(x.second.begin(), x.second.end(), row_cmp)) + { + std::cout << "C column not sorted: " << x.first << std::endl; + assert(0); + } + for (auto& x : B.columns()) + if (!std::is_sorted(x.second.begin(), x.second.end(), b_row_cmp)) + { + std::cout << "B column not sorted: " << x.first << std::endl; + assert(0); + } +} + diff --git a/src/tdautils/dionysusUtils.h b/src/tdautils/dionysusUtils.h index ce3fa04..bc74f56 100644 --- a/src/tdautils/dionysusUtils.h +++ b/src/tdautils/dionysusUtils.h @@ -1,7 +1,14 @@ #ifndef __DIONYSUSUTILS_H__ #define __DIONYSUSUTILS_H__ -#include +#include "../dionysus/dionysus/simplex.h" +#include "../dionysus/dionysus/rips.h" +#include "../dionysus/dionysus/filtration.h" +#include "../dionysus/dionysus/standard-reduction.h" + +// swapping simplex +//#include + #include #include @@ -385,7 +392,43 @@ inline Filtration RipsFiltrationDionysus( return filtration; } +/* +template< typename Distances, typename Generator, typename Filtration, + typename RealMatrix, typename Print > +inline Filtration RipsFiltrationDionysus2( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const bool is_row_names, + const int maxdimension, + const double maxscale, + const bool printProgress, + const Print & print +) { + // This is a Matrix of Points + PointContainer points = TdaToStl< PointContainer >(X, nSample, nDim, is_row_names); + // + Distances distances(points); + Generator rips(distances); + typename Generator::Evaluator size(distances); + Filtration filtration; + EvaluatePushBack< Filtration, typename Generator::Evaluator > functor( + filtration, size); + + // Generate maxdimension skeleton of the Rips complex + rips.generate(maxdimension + 1, maxscale, functor); + if (printProgress) { + print("# Generated complex of size: %d \n", filtration.size()); + } + + // Sort the simplices with respect to comparison criteria + // e.g. distance or function values + filtration.sort(ComparisonDataDimension< typename Filtration::Simplex >()); + + return filtration; +} +*/ -# endif // __DIONYSUSUTILS_H__ \ No newline at end of file +# endif // __DIONYSUSUTILS_H__ diff --git a/tests/testthat/test_kde.R b/tests/testthat/test_kde.R new file mode 100644 index 0000000..3e39ae3 --- /dev/null +++ b/tests/testthat/test_kde.R @@ -0,0 +1,20 @@ +context("KDE") + +test_that("kde returns 1 when points are the same", { + h <- 1/sqrt(2*pi) + X <- expand.grid(0,0);Grid <- expand.grid(0,0) + expect_equal(kde(X=X,Grid=Grid,h=1/sqrt(2*pi)),1) +}) + +test_that("peaks are higher than surrounding points", { + h <- 1/sqrt(2*pi) + Grid <- expand.grid(seq(-2,2,by = .01)) + X <- expand.grid(c(-1,1)) + KDE = kde(X = X, Grid = Grid, h=h) + checks = c(max(KDE) > KDE[1], max(KDE) > KDE[401], max(KDE) > KDE[201]) + for (bool in checks) { + expect_true(bool) + } +}) + + diff --git a/tests/testthat/test_rips.R b/tests/testthat/test_rips.R new file mode 100644 index 0000000..a479f6f --- /dev/null +++ b/tests/testthat/test_rips.R @@ -0,0 +1,29 @@ +context("ripsFiltration") + +test_that("default circle example ripsFiltration", { + n <- 5 + X <- cbind(cos(2*pi*seq_len(n)/n), sin(2*pi*seq_len(n)/n)) + maxdimension <- 1 + maxscale <- 1.5 + FltRips <- ripsFiltration(X = X, maxdimension = maxdimension, + maxscale = maxscale, dist = "euclidean", library = "Dionysus", + printProgress = TRUE) + expect_equal(FltRips$cmplx[[1]],1) + expect_true(FltRips$increasing) + expect_equal(FltRips$values[[1]],0) + expect_true(abs(FltRips$values[[8]]-1.175571)<.000001) +}) + +test_that("One dimensional ripsFiltration in a line", { + Y <- matrix(c(1,2.1,3.3,4.6,6)) + FltRips <- ripsFiltration(X =Y, maxdimension = 1, + maxscale = 1.5, dist = "euclidean", library = "Dionysus", + printProgress = TRUE) + expect_equal(FltRips$cmplx[[1]],1) + expect_equal(FltRips$cmplx[[6]],c(1,2)) + expect_equal(FltRips$cmplx[[9]],c(4,5)) + expect_equal(FltRips$cmplx[[8]],c(3,4)) + expect_true(FltRips$increasing) +}) + + From 038d2162c1e83a746acfe6da4328549b9a5446ae Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 18 Jul 2018 11:41:21 -0600 Subject: [PATCH 02/29] fixed labels? --- R/gridDiag.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/gridDiag.R b/R/gridDiag.R index f857b05..87fac18 100644 --- a/R/gridDiag.R +++ b/R/gridDiag.R @@ -127,7 +127,7 @@ function(X = NULL, FUN = NULL, lim = NULL, by = NULL, FUNvalues = NULL, Diag[1, 3] <- ifelse(is.null(diagLimit), max(FUNvalues), diagLimit) } if (sublevel == FALSE) { - colnames(Diag) <- c("dimension", "Death", "Birth") + colnames(Diag) <- c("dimension", "Birth", "Death") Diag[, 2:3] <- -Diag[, 3:2] } else { colnames(Diag) <- c("dimension", "Birth", "Death") From b62f2fc09771a2808b1c0137eb34024d18158db1 Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 18 Jul 2018 11:42:28 -0600 Subject: [PATCH 03/29] fixed labels? --- R/gridDiag.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/gridDiag.R b/R/gridDiag.R index 87fac18..399fcae 100644 --- a/R/gridDiag.R +++ b/R/gridDiag.R @@ -127,6 +127,7 @@ function(X = NULL, FUN = NULL, lim = NULL, by = NULL, FUNvalues = NULL, Diag[1, 3] <- ifelse(is.null(diagLimit), max(FUNvalues), diagLimit) } if (sublevel == FALSE) { + #possible bugfix with labels colnames(Diag) <- c("dimension", "Birth", "Death") Diag[, 2:3] <- -Diag[, 3:2] } else { From 17f0dd11cd1c89cb55272bb13378ad9b1fa32f72 Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 18 Jul 2018 14:49:54 -0600 Subject: [PATCH 04/29] switched superlevel by accident --- R/gridDiag.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/gridDiag.R b/R/gridDiag.R index 399fcae..070dc65 100644 --- a/R/gridDiag.R +++ b/R/gridDiag.R @@ -128,7 +128,7 @@ function(X = NULL, FUN = NULL, lim = NULL, by = NULL, FUNvalues = NULL, } if (sublevel == FALSE) { #possible bugfix with labels - colnames(Diag) <- c("dimension", "Birth", "Death") + colnames(Diag) <- c("dimension", "Death", "Birth") Diag[, 2:3] <- -Diag[, 3:2] } else { colnames(Diag) <- c("dimension", "Birth", "Death") From 745c5920675db95f91186e9d8cb9d8cd1c8a9061 Mon Sep 17 00:00:00 2001 From: thomashli Date: Thu, 19 Jul 2018 11:29:03 -0600 Subject: [PATCH 05/29] moved dionysus up 1 --- src/dionysus/backward.hpp | 2212 ----------------- src/dionysus/{dionysus => }/chain.h | 0 src/dionysus/{dionysus => }/chain.hpp | 0 .../{dionysus => }/clearing-reduction.h | 0 .../{dionysus => }/clearing-reduction.hpp | 0 .../{dionysus => }/cohomology-persistence.h | 0 .../{dionysus => }/cohomology-persistence.hpp | 0 src/dionysus/{dionysus => }/diagram.h | 0 src/dionysus/{dionysus => }/distances.h | 0 src/dionysus/{dionysus => }/distances.hpp | 0 src/dionysus/{dionysus => }/dlog/progress.h | 0 src/dionysus/{dionysus => }/fields/q.h | 0 src/dionysus/{dionysus => }/fields/z2.h | 0 src/dionysus/{dionysus => }/fields/zp.h | 0 src/dionysus/{dionysus => }/filtration.h | 0 .../{dionysus => }/omni-field-persistence.h | 0 .../{dionysus => }/omni-field-persistence.hpp | 0 .../{dionysus => }/ordinary-persistence.h | 0 src/dionysus/{dionysus => }/pair-recorder.h | 0 src/dionysus/{dionysus => }/reduced-matrix.h | 0 .../{dionysus => }/reduced-matrix.hpp | 0 src/dionysus/{dionysus => }/reduction.h | 0 .../{dionysus => }/relative-homology-zigzag.h | 0 .../relative-homology-zigzag.hpp | 0 src/dionysus/{dionysus => }/rips.h | 0 src/dionysus/{dionysus => }/rips.hpp | 0 src/dionysus/{dionysus => }/row-reduction.h | 0 src/dionysus/{dionysus => }/row-reduction.hpp | 0 src/dionysus/{dionysus => }/simplex.h | 0 .../{dionysus => }/sparse-row-matrix.h | 0 .../{dionysus => }/sparse-row-matrix.hpp | 0 .../{dionysus => }/standard-reduction.h | 0 .../{dionysus => }/standard-reduction.hpp | 0 src/dionysus/{dionysus => }/trails-chains.h | 0 .../{dionysus => }/zigzag-persistence.h | 0 .../{dionysus => }/zigzag-persistence.hpp | 0 src/tdautils/dionysusUtils.h | 7 +- 37 files changed, 3 insertions(+), 2216 deletions(-) delete mode 100755 src/dionysus/backward.hpp rename src/dionysus/{dionysus => }/chain.h (100%) rename src/dionysus/{dionysus => }/chain.hpp (100%) rename src/dionysus/{dionysus => }/clearing-reduction.h (100%) rename src/dionysus/{dionysus => }/clearing-reduction.hpp (100%) rename src/dionysus/{dionysus => }/cohomology-persistence.h (100%) rename src/dionysus/{dionysus => }/cohomology-persistence.hpp (100%) rename src/dionysus/{dionysus => }/diagram.h (100%) rename src/dionysus/{dionysus => }/distances.h (100%) rename src/dionysus/{dionysus => }/distances.hpp (100%) rename src/dionysus/{dionysus => }/dlog/progress.h (100%) rename src/dionysus/{dionysus => }/fields/q.h (100%) rename src/dionysus/{dionysus => }/fields/z2.h (100%) rename src/dionysus/{dionysus => }/fields/zp.h (100%) rename src/dionysus/{dionysus => }/filtration.h (100%) rename src/dionysus/{dionysus => }/omni-field-persistence.h (100%) rename src/dionysus/{dionysus => }/omni-field-persistence.hpp (100%) rename src/dionysus/{dionysus => }/ordinary-persistence.h (100%) rename src/dionysus/{dionysus => }/pair-recorder.h (100%) rename src/dionysus/{dionysus => }/reduced-matrix.h (100%) rename src/dionysus/{dionysus => }/reduced-matrix.hpp (100%) rename src/dionysus/{dionysus => }/reduction.h (100%) rename src/dionysus/{dionysus => }/relative-homology-zigzag.h (100%) rename src/dionysus/{dionysus => }/relative-homology-zigzag.hpp (100%) rename src/dionysus/{dionysus => }/rips.h (100%) rename src/dionysus/{dionysus => }/rips.hpp (100%) rename src/dionysus/{dionysus => }/row-reduction.h (100%) rename src/dionysus/{dionysus => }/row-reduction.hpp (100%) rename src/dionysus/{dionysus => }/simplex.h (100%) rename src/dionysus/{dionysus => }/sparse-row-matrix.h (100%) rename src/dionysus/{dionysus => }/sparse-row-matrix.hpp (100%) rename src/dionysus/{dionysus => }/standard-reduction.h (100%) rename src/dionysus/{dionysus => }/standard-reduction.hpp (100%) rename src/dionysus/{dionysus => }/trails-chains.h (100%) rename src/dionysus/{dionysus => }/zigzag-persistence.h (100%) rename src/dionysus/{dionysus => }/zigzag-persistence.hpp (100%) diff --git a/src/dionysus/backward.hpp b/src/dionysus/backward.hpp deleted file mode 100755 index 6b331ba..0000000 --- a/src/dionysus/backward.hpp +++ /dev/null @@ -1,2212 +0,0 @@ -/* - * backward.hpp - * Copyright 2013 Google Inc. All Rights Reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#ifndef H_6B9572DA_A64B_49E6_B234_051480991C89 -#define H_6B9572DA_A64B_49E6_B234_051480991C89 - -#ifndef __cplusplus -# error "It's not going to compile without a C++ compiler..." -#endif - -#if defined(BACKWARD_CXX11) -#elif defined(BACKWARD_CXX98) -#else -# if __cplusplus >= 201103L -# define BACKWARD_CXX11 -# else -# define BACKWARD_CXX98 -# endif -#endif - -// You can define one of the following (or leave it to the auto-detection): -// -// #define BACKWARD_SYSTEM_LINUX -// - specialization for linux -// -// #define BACKWARD_SYSTEM_UNKNOWN -// - placebo implementation, does nothing. -// -#if defined(BACKWARD_SYSTEM_LINUX) -#elif defined(BACKWARD_SYSTEM_UNKNOWN) -#else -# if defined(__linux) -# define BACKWARD_SYSTEM_LINUX -# else -# define BACKWARD_SYSTEM_UNKNOWN -# endif -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(BACKWARD_SYSTEM_LINUX) - -// On linux, backtrace can back-trace or "walk" the stack using the following -// library: -// -// #define BACKWARD_HAS_UNWIND 1 -// - unwind comes from libgcc, but I saw an equivalent inside clang itself. -// - with unwind, the stacktrace is as accurate as it can possibly be, since -// this is used by the C++ runtine in gcc/clang for stack unwinding on -// exception. -// - normally libgcc is already linked to your program by default. -// -// #define BACKWARD_HAS_BACKTRACE == 1 -// - backtrace seems to be a little bit more portable than libunwind, but on -// linux, it uses unwind anyway, but abstract away a tiny information that is -// sadly really important in order to get perfectly accurate stack traces. -// - backtrace is part of the (e)glib library. -// -// The default is: -// #define BACKWARD_HAS_UNWIND == 1 -// -# if BACKWARD_HAS_UNWIND == 1 -# elif BACKWARD_HAS_BACKTRACE == 1 -# else -# undef BACKWARD_HAS_UNWIND -# define BACKWARD_HAS_UNWIND 1 -# undef BACKWARD_HAS_BACKTRACE -# define BACKWARD_HAS_BACKTRACE 0 -# endif - -// On linux, backward can extract detailed information about a stack trace -// using one of the following library: -// -// #define BACKWARD_HAS_DW 1 -// - libdw gives you the most juicy details out of your stack traces: -// - object filename -// - function name -// - source filename -// - line and column numbers -// - source code snippet (assuming the file is accessible) -// - variables name and values (if not optimized out) -// - You need to link with the lib "dw": -// - apt-get install libdw-dev -// - g++/clang++ -ldw ... -// -// #define BACKWARD_HAS_BFD 1 -// - With libbfd, you get a fair about of details: -// - object filename -// - function name -// - source filename -// - line numbers -// - source code snippet (assuming the file is accessible) -// - You need to link with the lib "bfd": -// - apt-get install binutils-dev -// - g++/clang++ -lbfd ... -// -// #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 -// - backtrace provides minimal details for a stack trace: -// - object filename -// - function name -// - backtrace is part of the (e)glib library. -// -// The default is: -// #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1 -// -# if BACKWARD_HAS_DW == 1 -# elif BACKWARD_HAS_BFD == 1 -# elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 -# else -# undef BACKWARD_HAS_DW -# define BACKWARD_HAS_DW 0 -# undef BACKWARD_HAS_BFD -# define BACKWARD_HAS_BFD 0 -# undef BACKWARD_HAS_BACKTRACE_SYMBOL -# define BACKWARD_HAS_BACKTRACE_SYMBOL 1 -# endif - - -# if BACKWARD_HAS_UNWIND == 1 - -# include -// while gcc's unwind.h defines something like that: -// extern _Unwind_Ptr _Unwind_GetIP (struct _Unwind_Context *); -// extern _Unwind_Ptr _Unwind_GetIPInfo (struct _Unwind_Context *, int *); -// -// clang's unwind.h defines something like this: -// uintptr_t _Unwind_GetIP(struct _Unwind_Context* __context); -// -// Even if the _Unwind_GetIPInfo can be linked to, it is not declared, worse we -// cannot just redeclare it because clang's unwind.h doesn't define _Unwind_Ptr -// anyway. -// -// Luckily we can play on the fact that the guard macros have a different name: -#ifdef __CLANG_UNWIND_H -// In fact, this function still comes from libgcc (on my different linux boxes, -// clang links against libgcc). -# include -extern "C" uintptr_t _Unwind_GetIPInfo(_Unwind_Context*, int*); -#endif - -# endif - -# include -# include -# include -# include -# include -# include -# include - -# if BACKWARD_HAS_BFD == 1 -# include -# ifndef _GNU_SOURCE -# define _GNU_SOURCE -# include -# undef _GNU_SOURCE -# else -# include -# endif -# endif - -# if BACKWARD_HAS_DW == 1 -# include -# include -# include -# endif - -# if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1) - // then we shall rely on backtrace -# include -# endif - -#endif // defined(BACKWARD_SYSTEM_LINUX) - -#if defined(BACKWARD_CXX11) -# include -# include // for std::swap - namespace backward { - namespace details { - template - struct hashtable { - typedef std::unordered_map type; - }; - using std::move; - } // namespace details - } // namespace backward -#elif defined(BACKWARD_CXX98) -# include - namespace backward { - namespace details { - template - struct hashtable { - typedef std::map type; - }; - template - const T& move(const T& v) { return v; } - template - T& move(T& v) { return v; } - } // namespace details - } // namespace backward -#else -# error "Mmm if its not C++11 nor C++98... go play in the toaster." -#endif - -namespace backward { - -namespace system_tag { - struct linux_tag; // seems that I cannot call that "linux" because the name - // is already defined... so I am adding _tag everywhere. - struct unknown_tag; - -#if defined(BACKWARD_SYSTEM_LINUX) - typedef linux_tag current_tag; -#elif defined(BACKWARD_SYSTEM_UNKNOWN) - typedef unknown_tag current_tag; -#else -# error "May I please get my system defines?" -#endif -} // namespace system_tag - - -namespace stacktrace_tag { -#ifdef BACKWARD_SYSTEM_LINUX - struct unwind; - struct backtrace; - -# if BACKWARD_HAS_UNWIND == 1 - typedef unwind current; -# elif BACKWARD_HAS_BACKTRACE == 1 - typedef backtrace current; -# else -# error "I know it's difficult but you need to make a choice!" -# endif -#endif // BACKWARD_SYSTEM_LINUX -} // namespace stacktrace_tag - - -namespace trace_resolver_tag { -#ifdef BACKWARD_SYSTEM_LINUX - struct libdw; - struct libbfd; - struct backtrace_symbol; - -# if BACKWARD_HAS_DW == 1 - typedef libdw current; -# elif BACKWARD_HAS_BFD == 1 - typedef libbfd current; -# elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 - typedef backtrace_symbol current; -# else -# error "You shall not pass, until you know what you want." -# endif -#endif // BACKWARD_SYSTEM_LINUX -} // namespace trace_resolver_tag - -namespace details { - -template - struct rm_ptr { typedef T type; }; - -template - struct rm_ptr { typedef T type; }; - -template - struct rm_ptr { typedef const T type; }; - -template -struct deleter { - template - void operator()(U& ptr) const { - (*F)(ptr); - } -}; - -template -struct default_delete { - void operator()(T& ptr) const { - delete ptr; - } -}; - -template > -class handle { - struct dummy; - T _val; - bool _empty; - -#if defined(BACKWARD_CXX11) - handle(const handle&) = delete; - handle& operator=(const handle&) = delete; -#endif - -public: - ~handle() { - if (not _empty) { - Deleter()(_val); - } - } - - explicit handle(): _val(), _empty(true) {} - explicit handle(T val): _val(val), _empty(false) {} - -#if defined(BACKWARD_CXX11) - handle(handle&& from): _empty(true) { - swap(from); - } - handle& operator=(handle&& from) { - swap(from); return *this; - } -#else - explicit handle(const handle& from): _empty(true) { - // some sort of poor man's move semantic. - swap(const_cast(from)); - } - handle& operator=(const handle& from) { - // some sort of poor man's move semantic. - swap(const_cast(from)); return *this; - } -#endif - - void reset(T new_val) { - handle tmp(new_val); - swap(tmp); - } - operator const dummy*() const { - if (_empty) { - return 0; - } - return reinterpret_cast(_val); - } - T get() { - return _val; - } - T release() { - _empty = true; - return _val; - } - void swap(handle& b) { - using std::swap; - swap(b._val, _val); // can throw, we are safe here. - swap(b._empty, _empty); // should not throw: if you cannot swap two - // bools without throwing... It's a lost cause anyway! - } - - T operator->() { return _val; } - const T operator->() const { return _val; } - - typedef typename rm_ptr::type& ref_t; - ref_t operator*() { return *_val; } - const ref_t operator*() const { return *_val; } - ref_t operator[](size_t idx) { return _val[idx]; } - - // Watch out, we've got a badass over here - T* operator&() { - _empty = false; - return &_val; - } -}; - -} // namespace details - -/*************** A TRACE ***************/ - -struct Trace { - void* addr; - size_t idx; - - Trace(): - addr(0), idx(0) {} - - explicit Trace(void* addr, size_t idx): - addr(addr), idx(idx) {} -}; - -// Really simple, generic, and dumb representation of a variable. -// A variable has a name and can represent either: -// - a value (as a string) -// - a list of values (a list of strings) -// - a map of values (a list of variable) -class Variable { -public: - enum Kind { VALUE, LIST, MAP }; - - typedef std::vector list_t; - typedef std::vector map_t; - - std::string name; - Kind kind; - - Variable(Kind k): kind(k) { - switch (k) { - case VALUE: - new (&storage) std::string(); - break; - - case LIST: - new (&storage) list_t(); - break; - - case MAP: - new (&storage) map_t(); - break; - } - } - - std::string& value() { - return reinterpret_cast(storage); - } - list_t& list() { - return reinterpret_cast(storage); - } - map_t& map() { - return reinterpret_cast(storage); - } - - - const std::string& value() const { - return reinterpret_cast(storage); - } - const list_t& list() const { - return reinterpret_cast(storage); - } - const map_t& map() const { - return reinterpret_cast(storage); - } - -private: - // the C++98 style union for non-trivial objects, yes yes I know, its not - // aligned as good as it can be, blabla... Screw this. - union { - char s1[sizeof (std::string)]; - char s2[sizeof (list_t)]; - char s3[sizeof (map_t)]; - } storage; -}; - -struct TraceWithLocals: public Trace { - // Locals variable and values. - std::vector locals; - - TraceWithLocals(): Trace() {} - TraceWithLocals(const Trace& mini_trace): - Trace(mini_trace) {} -}; - -struct ResolvedTrace: public TraceWithLocals { - - struct SourceLoc { - std::string function; - std::string filename; - unsigned line; - unsigned col; - - SourceLoc(): line(0), col(0) {} - - bool operator==(const SourceLoc& b) const { - return function == b.function - and filename == b.filename - and line == b.line - and col == b.col; - } - - bool operator!=(const SourceLoc& b) const { - return not (*this == b); - } - }; - - // In which binary object this trace is located. - std::string object_filename; - - // The function in the object that contain the trace. This is not the same - // as source.function which can be an function inlined in object_function. - std::string object_function; - - // The source location of this trace. It is possible for filename to be - // empty and for line/col to be invalid (value 0) if this information - // couldn't be deduced, for example if there is no debug information in the - // binary object. - SourceLoc source; - - // An optionals list of "inliners". All the successive sources location - // from where the source location of the trace (the attribute right above) - // is inlined. It is especially useful when you compiled with optimization. - typedef std::vector source_locs_t; - source_locs_t inliners; - - ResolvedTrace(const Trace& mini_trace): - TraceWithLocals(mini_trace) {} - ResolvedTrace(const TraceWithLocals& mini_trace_with_locals): - TraceWithLocals(mini_trace_with_locals) {} -}; - -/*************** STACK TRACE ***************/ - -// default implemention. -template -class StackTraceImpl { -public: - size_t size() const { return 0; } - Trace operator[](size_t) { return Trace(); } - size_t load_here(size_t=0) { return 0; } - size_t load_from(void*, size_t=0) { return 0; } - unsigned thread_id() const { return 0; } -}; - -#ifdef BACKWARD_SYSTEM_LINUX - -class StackTraceLinuxImplBase { -public: - StackTraceLinuxImplBase(): _thread_id(0), _skip(0) {} - - unsigned thread_id() const { - return _thread_id; - } - -protected: - void load_thread_info() { - _thread_id = syscall(SYS_gettid); - if (_thread_id == (size_t) getpid()) { - // If the thread is the main one, let's hide that. - // I like to keep little secret sometimes. - _thread_id = 0; - } - } - - void skip_n_firsts(size_t n) { _skip = n; } - size_t skip_n_firsts() const { return _skip; } - -private: - size_t _thread_id; - size_t _skip; -}; - -class StackTraceLinuxImplHolder: public StackTraceLinuxImplBase { -public: - size_t size() const { - return _stacktrace.size() ? _stacktrace.size() - skip_n_firsts() : 0; - } - Trace operator[](size_t idx) { - if (idx >= size()) { - return Trace(); - } - return Trace(_stacktrace[idx + skip_n_firsts()], idx); - } - void** begin() { - if (size()) { - return &_stacktrace[skip_n_firsts()]; - } - return 0; - } - -protected: - std::vector _stacktrace; -}; - - -#if BACKWARD_HAS_UNWIND == 1 - -namespace details { - -template -class Unwinder { -public: - size_t operator()(F& f, size_t depth) { - _f = &f; - _index = -1; - _depth = depth; - _Unwind_Backtrace(&this->backtrace_trampoline, this); - return _index; - } - -private: - F* _f; - ssize_t _index; - size_t _depth; - - static _Unwind_Reason_Code backtrace_trampoline( - _Unwind_Context* ctx, void *self) { - return ((Unwinder*)self)->backtrace(ctx); - } - - _Unwind_Reason_Code backtrace(_Unwind_Context* ctx) { - if (_index >= 0 and static_cast(_index) >= _depth) - return _URC_END_OF_STACK; - - int ip_before_instruction = 0; - uintptr_t ip = _Unwind_GetIPInfo(ctx, &ip_before_instruction); - - if (not ip_before_instruction) { - ip -= 1; - } - - if (_index >= 0) { // ignore first frame. - (*_f)(_index, (void*)ip); - } - _index += 1; - return _URC_NO_REASON; - } -}; - -template -size_t unwind(F f, size_t depth) { - Unwinder unwinder; - return unwinder(f, depth); -} - -} // namespace details - - -template <> -class StackTraceImpl: public StackTraceLinuxImplHolder { -public: - __attribute__ ((noinline)) // TODO use some macro - size_t load_here(size_t depth=32) { - load_thread_info(); - if (depth == 0) { - return 0; - } - _stacktrace.resize(depth); - size_t trace_cnt = details::unwind(callback(*this), depth); - _stacktrace.resize(trace_cnt); - skip_n_firsts(0); - return size(); - } - size_t load_from(void* addr, size_t depth=32) { - load_here(depth + 8); - - for (size_t i = 0; i < _stacktrace.size(); ++i) { - if (_stacktrace[i] == addr) { - skip_n_firsts(i); - break; - } - } - - _stacktrace.resize(std::min(_stacktrace.size(), - skip_n_firsts() + depth)); - return size(); - } - -private: - struct callback { - StackTraceImpl& self; - callback(StackTraceImpl& self): self(self) {} - - void operator()(size_t idx, void* addr) { - self._stacktrace[idx] = addr; - } - }; -}; - - -#else // BACKWARD_HAS_UNWIND == 0 - -template <> -class StackTraceImpl: public StackTraceLinuxImplHolder { -public: - __attribute__ ((noinline)) // TODO use some macro - size_t load_here(size_t depth=32) { - load_thread_info(); - if (depth == 0) { - return 0; - } - _stacktrace.resize(depth + 1); - size_t trace_cnt = backtrace(&_stacktrace[0], _stacktrace.size()); - _stacktrace.resize(trace_cnt); - skip_n_firsts(1); - return size(); - } - - size_t load_from(void* addr, size_t depth=32) { - load_here(depth + 8); - - for (size_t i = 0; i < _stacktrace.size(); ++i) { - if (_stacktrace[i] == addr) { - skip_n_firsts(i); - _stacktrace[i] = (void*)( (uintptr_t)_stacktrace[i] + 1); - break; - } - } - - _stacktrace.resize(std::min(_stacktrace.size(), - skip_n_firsts() + depth)); - return size(); - } -}; - -#endif // BACKWARD_HAS_UNWIND -#endif // BACKWARD_SYSTEM_LINUX - -class StackTrace: - public StackTraceImpl {}; - -/*********** STACKTRACE WITH LOCALS ***********/ - -// default implemention. -template -class StackTraceWithLocalsImpl: - public StackTrace {}; - -#ifdef BACKWARD_SYSTEM_LINUX -#if BACKWARD_HAS_UNWIND -#if BACKWARD_HAS_DW - -template <> -class StackTraceWithLocalsImpl: - public StackTraceLinuxImplBase { -public: - __attribute__ ((noinline)) // TODO use some macro - size_t load_here(size_t depth=32) { - load_thread_info(); - if (depth == 0) { - return 0; - } - _stacktrace.resize(depth); - size_t trace_cnt = details::unwind(callback(*this), depth); - _stacktrace.resize(trace_cnt); - skip_n_firsts(0); - return size(); - } - size_t load_from(void* addr, size_t depth=32) { - load_here(depth + 8); - - for (size_t i = 0; i < _stacktrace.size(); ++i) { - if (_stacktrace[i].addr == addr) { - skip_n_firsts(i); - break; - } - } - _stacktrace.resize(std::min(_stacktrace.size(), - skip_n_firsts() + depth)); - return size(); - } - size_t size() const { - return _stacktrace.size() ? _stacktrace.size() - skip_n_firsts() : 0; - } - const TraceWithLocals& operator[](size_t idx) { - if (idx >= size()) { - return _nil_trace; - } - return _stacktrace[idx + skip_n_firsts()]; - } - -private: - std::vector _stacktrace; - TraceWithLocals _nil_trace; - - void resolve_trace(TraceWithLocals& trace) { - Variable v(Variable::VALUE); - v.name = "var"; - v.value() = "42"; - trace.locals.push_back(v); - } - - struct callback { - StackTraceWithLocalsImpl& self; - callback(StackTraceWithLocalsImpl& self): self(self) {} - - void operator()(size_t idx, void* addr) { - self._stacktrace[idx].addr = addr; - self.resolve_trace(self._stacktrace[idx]); - } - }; -}; - -#endif // BACKWARD_HAS_DW -#endif // BACKWARD_HAS_UNWIND -#endif // BACKWARD_SYSTEM_LINUX - -class StackTraceWithLocals: - public StackTraceWithLocalsImpl {}; - -/*************** TRACE RESOLVER ***************/ - -template -class TraceResolverImpl; - -#ifdef BACKWARD_SYSTEM_UNKNOWN - -template <> -class TraceResolverImpl { -public: - template - void load_stacktrace(ST&) {} - ResolvedTrace resolve(ResolvedTrace t) { - return t; - } -}; - -#endif - -#ifdef BACKWARD_SYSTEM_LINUX - -class TraceResolverLinuxImplBase { -protected: - std::string demangle(const char* funcname) { - using namespace details; - _demangle_buffer.reset( - abi::__cxa_demangle(funcname, _demangle_buffer.release(), - &_demangle_buffer_length, 0) - ); - if (_demangle_buffer) { - return _demangle_buffer.get(); - } - return funcname; - } - -private: - details::handle _demangle_buffer; - size_t _demangle_buffer_length; -}; - -template -class TraceResolverLinuxImpl; - -#if BACKWARD_HAS_BACKTRACE_SYMBOL == 1 - -template <> -class TraceResolverLinuxImpl: - public TraceResolverLinuxImplBase { -public: - template - void load_stacktrace(ST& st) { - using namespace details; - if (st.size() == 0) { - return; - } - _symbols.reset( - backtrace_symbols(st.begin(), st.size()) - ); - } - - ResolvedTrace resolve(ResolvedTrace trace) { - char* filename = _symbols[trace.idx]; - char* funcname = filename; - while (*funcname && *funcname != '(') { - funcname += 1; - } - trace.object_filename.assign(filename, funcname++); - char* funcname_end = funcname; - while (*funcname_end && *funcname_end != ')' && *funcname_end != '+') { - funcname_end += 1; - } - *funcname_end = '\0'; - trace.object_function = this->demangle(funcname); - trace.source.function = trace.object_function; // we cannot do better. - return trace; - } - -private: - details::handle _symbols; -}; - -#endif // BACKWARD_HAS_BACKTRACE_SYMBOL == 1 - -#if BACKWARD_HAS_BFD == 1 - -template <> -class TraceResolverLinuxImpl: - public TraceResolverLinuxImplBase { -public: - TraceResolverLinuxImpl(): _bfd_loaded(false) {} - - template - void load_stacktrace(ST&) {} - - ResolvedTrace resolve(ResolvedTrace trace) { - Dl_info symbol_info; - - // trace.addr is a virtual address in memory pointing to some code. - // Let's try to find from which loaded object it comes from. - // The loaded object can be yourself btw. - if (not dladdr(trace.addr, &symbol_info)) { - return trace; // dat broken trace... - } - - // Now we get in symbol_info: - // .dli_fname: - // pathname of the shared object that contains the address. - // .dli_fbase: - // where the object is loaded in memory. - // .dli_sname: - // the name of the nearest symbol to trace.addr, we expect a - // function name. - // .dli_saddr: - // the exact address corresponding to .dli_sname. - - if (symbol_info.dli_sname) { - trace.object_function = demangle(symbol_info.dli_sname); - } - - if (not symbol_info.dli_fname) { - return trace; - } - - trace.object_filename = symbol_info.dli_fname; - bfd_fileobject& fobj = load_object_with_bfd(symbol_info.dli_fname); - if (not fobj.handle) { - return trace; // sad, we couldn't load the object :( - } - - - find_sym_result* details_selected; // to be filled. - - // trace.addr is the next instruction to be executed after returning - // from the nested stack frame. In C++ this usually relate to the next - // statement right after the function call that leaded to a new stack - // frame. This is not usually what you want to see when printing out a - // stacktrace... - find_sym_result details_call_site = find_symbol_details(fobj, - trace.addr, symbol_info.dli_fbase); - details_selected = &details_call_site; - -#if BACKWARD_HAS_UNWIND == 0 - // ...this is why we also try to resolve the symbol that is right - // before the return address. If we are lucky enough, we will get the - // line of the function that was called. But if the code is optimized, - // we might get something absolutely not related since the compiler - // can reschedule the return address with inline functions and - // tail-call optimisation (among other things that I don't even know - // or cannot even dream about with my tiny limited brain). - find_sym_result details_adjusted_call_site = find_symbol_details(fobj, - (void*) (uintptr_t(trace.addr) - 1), - symbol_info.dli_fbase); - - // In debug mode, we should always get the right thing(TM). - if (details_call_site.found and details_adjusted_call_site.found) { - // Ok, we assume that details_adjusted_call_site is a better estimation. - details_selected = &details_adjusted_call_site; - trace.addr = (void*) (uintptr_t(trace.addr) - 1); - } - - if (details_selected == &details_call_site and details_call_site.found) { - // we have to re-resolve the symbol in order to reset some - // internal state in BFD... so we can call backtrace_inliners - // thereafter... - details_call_site = find_symbol_details(fobj, trace.addr, - symbol_info.dli_fbase); - } -#endif // BACKWARD_HAS_UNWIND - - if (details_selected->found) { - if (details_selected->filename) { - trace.source.filename = details_selected->filename; - } - trace.source.line = details_selected->line; - - if (details_selected->funcname) { - // this time we get the name of the function where the code is - // located, instead of the function were the address is - // located. In short, if the code was inlined, we get the - // function correspoding to the code. Else we already got in - // trace.function. - trace.source.function = demangle(details_selected->funcname); - - if (not symbol_info.dli_sname) { - // for the case dladdr failed to find the symbol name of - // the function, we might as well try to put something - // here. - trace.object_function = trace.source.function; - } - } - - // Maybe the source of the trace got inlined inside the function - // (trace.source.function). Let's see if we can get all the inlined - // calls along the way up to the initial call site. - trace.inliners = backtrace_inliners(fobj, *details_selected); - -#if 0 - if (trace.inliners.size() == 0) { - // Maybe the trace was not inlined... or maybe it was and we - // are lacking the debug information. Let's try to make the - // world better and see if we can get the line number of the - // function (trace.source.function) now. - // - // We will get the location of where the function start (to be - // exact: the first instruction that really start the - // function), not where the name of the function is defined. - // This can be quite far away from the name of the function - // btw. - // - // If the source of the function is the same as the source of - // the trace, we cannot say if the trace was really inlined or - // not. However, if the filename of the source is different - // between the function and the trace... we can declare it as - // an inliner. This is not 100% accurate, but better than - // nothing. - - if (symbol_info.dli_saddr) { - find_sym_result details = find_symbol_details(fobj, - symbol_info.dli_saddr, - symbol_info.dli_fbase); - - if (details.found) { - ResolvedTrace::SourceLoc diy_inliner; - diy_inliner.line = details.line; - if (details.filename) { - diy_inliner.filename = details.filename; - } - if (details.funcname) { - diy_inliner.function = demangle(details.funcname); - } else { - diy_inliner.function = trace.source.function; - } - if (diy_inliner != trace.source) { - trace.inliners.push_back(diy_inliner); - } - } - } - } -#endif - } - - return trace; - } - -private: - bool _bfd_loaded; - - typedef details::handle - > bfd_handle_t; - - typedef details::handle bfd_symtab_t; - - - struct bfd_fileobject { - bfd_handle_t handle; - bfd_vma base_addr; - bfd_symtab_t symtab; - bfd_symtab_t dynamic_symtab; - }; - - typedef details::hashtable::type - fobj_bfd_map_t; - fobj_bfd_map_t _fobj_bfd_map; - - bfd_fileobject& load_object_with_bfd(const std::string& filename_object) { - using namespace details; - - if (not _bfd_loaded) { - using namespace details; - bfd_init(); - _bfd_loaded = true; - } - - fobj_bfd_map_t::iterator it = - _fobj_bfd_map.find(filename_object); - if (it != _fobj_bfd_map.end()) { - return it->second; - } - - // this new object is empty for now. - bfd_fileobject& r = _fobj_bfd_map[filename_object]; - - // we do the work temporary in this one; - bfd_handle_t bfd_handle; - - int fd = open(filename_object.c_str(), O_RDONLY); - bfd_handle.reset( - bfd_fdopenr(filename_object.c_str(), "default", fd) - ); - if (not bfd_handle) { - close(fd); - return r; - } - - if (not bfd_check_format(bfd_handle.get(), bfd_object)) { - return r; // not an object? You lose. - } - - if ((bfd_get_file_flags(bfd_handle.get()) & HAS_SYMS) == 0) { - return r; // that's what happen when you forget to compile in debug. - } - - ssize_t symtab_storage_size = - bfd_get_symtab_upper_bound(bfd_handle.get()); - - ssize_t dyn_symtab_storage_size = - bfd_get_dynamic_symtab_upper_bound(bfd_handle.get()); - - if (symtab_storage_size <= 0 and dyn_symtab_storage_size <= 0) { - return r; // weird, is the file is corrupted? - } - - bfd_symtab_t symtab, dynamic_symtab; - ssize_t symcount = 0, dyn_symcount = 0; - - if (symtab_storage_size > 0) { - symtab.reset( - (bfd_symbol**) malloc(symtab_storage_size) - ); - symcount = bfd_canonicalize_symtab( - bfd_handle.get(), symtab.get() - ); - } - - if (dyn_symtab_storage_size > 0) { - dynamic_symtab.reset( - (bfd_symbol**) malloc(dyn_symtab_storage_size) - ); - dyn_symcount = bfd_canonicalize_dynamic_symtab( - bfd_handle.get(), dynamic_symtab.get() - ); - } - - - if (symcount <= 0 and dyn_symcount <= 0) { - return r; // damned, that's a stripped file that you got there! - } - - r.handle = move(bfd_handle); - r.symtab = move(symtab); - r.dynamic_symtab = move(dynamic_symtab); - return r; - } - - struct find_sym_result { - bool found; - const char* filename; - const char* funcname; - unsigned int line; - }; - - struct find_sym_context { - TraceResolverLinuxImpl* self; - bfd_fileobject* fobj; - void* addr; - void* base_addr; - find_sym_result result; - }; - - find_sym_result find_symbol_details(bfd_fileobject& fobj, void* addr, - void* base_addr) { - find_sym_context context; - context.self = this; - context.fobj = &fobj; - context.addr = addr; - context.base_addr = base_addr; - context.result.found = false; - bfd_map_over_sections(fobj.handle.get(), &find_in_section_trampoline, - (void*)&context); - return context.result; - } - - static void find_in_section_trampoline(bfd*, asection* section, - void* data) { - find_sym_context* context = static_cast(data); - context->self->find_in_section( - reinterpret_cast(context->addr), - reinterpret_cast(context->base_addr), - *context->fobj, - section, context->result - ); - } - - void find_in_section(bfd_vma addr, bfd_vma base_addr, - bfd_fileobject& fobj, asection* section, find_sym_result& result) - { - if (result.found) return; - - if ((bfd_get_section_flags(fobj.handle.get(), section) - & SEC_ALLOC) == 0) - return; // a debug section is never loaded automatically. - - bfd_vma sec_addr = bfd_get_section_vma(fobj.handle.get(), section); - bfd_size_type size = bfd_get_section_size(section); - - // are we in the boundaries of the section? - if (addr < sec_addr or addr >= sec_addr + size) { - addr -= base_addr; // oups, a relocated object, lets try again... - if (addr < sec_addr or addr >= sec_addr + size) { - return; - } - } - - if (not result.found and fobj.symtab) { - result.found = bfd_find_nearest_line(fobj.handle.get(), section, - fobj.symtab.get(), addr - sec_addr, &result.filename, - &result.funcname, &result.line); - } - - if (not result.found and fobj.dynamic_symtab) { - result.found = bfd_find_nearest_line(fobj.handle.get(), section, - fobj.dynamic_symtab.get(), addr - sec_addr, - &result.filename, &result.funcname, &result.line); - } - - } - - ResolvedTrace::source_locs_t backtrace_inliners(bfd_fileobject& fobj, - find_sym_result previous_result) { - // This function can be called ONLY after a SUCCESSFUL call to - // find_symbol_details. The state is global to the bfd_handle. - ResolvedTrace::source_locs_t results; - while (previous_result.found) { - find_sym_result result; - result.found = bfd_find_inliner_info(fobj.handle.get(), - &result.filename, &result.funcname, &result.line); - - if (result.found) /* and not ( - cstrings_eq(previous_result.filename, result.filename) - and cstrings_eq(previous_result.funcname, result.funcname) - and result.line == previous_result.line - )) */ { - ResolvedTrace::SourceLoc src_loc; - src_loc.line = result.line; - if (result.filename) { - src_loc.filename = result.filename; - } - if (result.funcname) { - src_loc.function = demangle(result.funcname); - } - results.push_back(src_loc); - } - previous_result = result; - } - return results; - } - - bool cstrings_eq(const char* a, const char* b) { - if (not a or not b) { - return false; - } - return strcmp(a, b) == 0; - } - -}; -#endif // BACKWARD_HAS_BFD == 1 - -#if BACKWARD_HAS_DW == 1 - -template <> -class TraceResolverLinuxImpl: - public TraceResolverLinuxImplBase { -public: - TraceResolverLinuxImpl(): _dwfl_handle_initialized(false) {} - - template - void load_stacktrace(ST&) {} - - ResolvedTrace resolve(ResolvedTrace trace) { - using namespace details; - - Dwarf_Addr trace_addr = (Dwarf_Addr) trace.addr; - - if (not _dwfl_handle_initialized) { - // initialize dwfl... - _dwfl_cb.reset(new Dwfl_Callbacks); - _dwfl_cb->find_elf = &dwfl_linux_proc_find_elf; - _dwfl_cb->find_debuginfo = &dwfl_standard_find_debuginfo; - _dwfl_cb->debuginfo_path = 0; - - _dwfl_handle.reset(dwfl_begin(_dwfl_cb.get())); - _dwfl_handle_initialized = true; - - if (not _dwfl_handle) { - return trace; - } - - // ...from the current process. - dwfl_report_begin(_dwfl_handle.get()); - int r = dwfl_linux_proc_report (_dwfl_handle.get(), getpid()); - dwfl_report_end(_dwfl_handle.get(), NULL, NULL); - if (r < 0) { - return trace; - } - } - - if (not _dwfl_handle) { - return trace; - } - - // find the module (binary object) that contains the trace's address. - // This is not using any debug information, but the addresses ranges of - // all the currently loaded binary object. - Dwfl_Module* mod = dwfl_addrmodule(_dwfl_handle.get(), trace_addr); - if (mod) { - // now that we found it, lets get the name of it, this will be the - // full path to the running binary or one of the loaded library. - const char* module_name = dwfl_module_info (mod, - 0, 0, 0, 0, 0, 0, 0); - if (module_name) { - trace.object_filename = module_name; - } - // We also look after the name of the symbol, equal or before this - // address. This is found by walking the symtab. We should get the - // symbol corresponding to the function (mangled) containing the - // address. If the code corresponding to the address was inlined, - // this is the name of the out-most inliner function. - const char* sym_name = dwfl_module_addrname(mod, trace_addr); - if (sym_name) { - trace.object_function = demangle(sym_name); - } - } - - // now let's get serious, and find out the source location (file and - // line number) of the address. - - // This function will look in .debug_aranges for the address and map it - // to the location of the compilation unit DIE in .debug_info and - // return it. - Dwarf_Addr mod_bias = 0; - Dwarf_Die* cudie = dwfl_module_addrdie(mod, trace_addr, &mod_bias); - -#if 1 - if (not cudie) { - // Sadly clang does not generate the section .debug_aranges, thus - // dwfl_module_addrdie will fail early. Clang doesn't either set - // the lowpc/highpc/range info for every compilation unit. - // - // So in order to save the world: - // for every compilation unit, we will iterate over every single - // DIEs. Normally functions should have a lowpc/highpc/range, which - // we will use to infer the compilation unit. - - // note that this is probably badly inefficient. - while ((cudie = dwfl_module_nextcu(mod, cudie, &mod_bias))) { - Dwarf_Die die_mem; - Dwarf_Die* fundie = find_fundie_by_pc(cudie, - trace_addr - mod_bias, &die_mem); - if (fundie) { - break; - } - } - } -#endif - -//#define BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE -#ifdef BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE - if (not cudie) { - // If it's still not enough, lets dive deeper in the shit, and try - // to save the world again: for every compilation unit, we will - // load the corresponding .debug_line section, and see if we can - // find our address in it. - - Dwarf_Addr cfi_bias; - Dwarf_CFI* cfi_cache = dwfl_module_eh_cfi(mod, &cfi_bias); - - Dwarf_Addr bias; - while ((cudie = dwfl_module_nextcu(mod, cudie, &bias))) { - if (dwarf_getsrc_die(cudie, trace_addr - bias)) { - - // ...but if we get a match, it might be a false positive - // because our (address - bias) might as well be valid in a - // different compilation unit. So we throw our last card on - // the table and lookup for the address into the .eh_frame - // section. - - handle frame; - dwarf_cfi_addrframe(cfi_cache, trace_addr - cfi_bias, &frame); - if (frame) { - break; - } - } - } - } -#endif - - if (not cudie) { - return trace; // this time we lost the game :/ - } - - // Now that we have a compilation unit DIE, this function will be able - // to load the corresponding section in .debug_line (if not already - // loaded) and hopefully find the source location mapped to our - // address. - Dwarf_Line* srcloc = dwarf_getsrc_die(cudie, trace_addr - mod_bias); - - if (srcloc) { - const char* srcfile = dwarf_linesrc(srcloc, 0, 0); - if (srcfile) { - trace.source.filename = srcfile; - } - int line = 0, col = 0; - dwarf_lineno(srcloc, &line); - dwarf_linecol(srcloc, &col); - trace.source.line = line; - trace.source.col = col; - } - - deep_first_search_by_pc(cudie, trace_addr - mod_bias, - inliners_search_cb(trace)); - if (trace.source.function.size() == 0) { - // fallback. - trace.source.function = trace.object_function; - } - - return trace; - } - -private: - typedef details::handle > - dwfl_handle_t; - details::handle > - _dwfl_cb; - dwfl_handle_t _dwfl_handle; - bool _dwfl_handle_initialized; - - // defined here because in C++98, template function cannot take locally - // defined types... grrr. - struct inliners_search_cb { - void operator()(Dwarf_Die* die) { - switch (dwarf_tag(die)) { - const char* name; - case DW_TAG_subprogram: - if ((name = dwarf_diename(die))) { - trace.source.function = name; - } - break; - - case DW_TAG_inlined_subroutine: - ResolvedTrace::SourceLoc sloc; - Dwarf_Attribute attr_mem; - - if ((name = dwarf_diename(die))) { - trace.source.function = name; - } - if ((name = die_call_file(die))) { - sloc.filename = name; - } - - Dwarf_Word line = 0, col = 0; - dwarf_formudata(dwarf_attr(die, DW_AT_call_line, - &attr_mem), &line); - dwarf_formudata(dwarf_attr(die, DW_AT_call_column, - &attr_mem), &col); - sloc.line = line; - sloc.col = col; - - trace.inliners.push_back(sloc); - break; - }; - } - ResolvedTrace& trace; - inliners_search_cb(ResolvedTrace& t): trace(t) {} - }; - - - static bool die_has_pc(Dwarf_Die* die, Dwarf_Addr pc) { - Dwarf_Addr low, high; - - // continuous range - if (dwarf_hasattr(die, DW_AT_low_pc) and - dwarf_hasattr(die, DW_AT_high_pc)) { - if (dwarf_lowpc(die, &low) != 0) { - return false; - } - if (dwarf_highpc(die, &high) != 0) { - Dwarf_Attribute attr_mem; - Dwarf_Attribute* attr = dwarf_attr(die, DW_AT_high_pc, &attr_mem); - Dwarf_Word value; - if (dwarf_formudata(attr, &value) != 0) { - return false; - } - high = low + value; - } - return pc >= low and pc < high; - } - - // non-continuous range. - Dwarf_Addr base; - ptrdiff_t offset = 0; - while ((offset = dwarf_ranges(die, offset, &base, &low, &high)) > 0) { - if (pc >= low and pc < high) { - return true; - } - } - return false; - } - - static Dwarf_Die* find_fundie_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, - Dwarf_Die* result) { - if (dwarf_child(parent_die, result) != 0) { - return 0; - } - - Dwarf_Die* die = result; - do { - switch (dwarf_tag(die)) { - case DW_TAG_subprogram: - case DW_TAG_inlined_subroutine: - if (die_has_pc(die, pc)) { - return result; - } - default: - bool declaration = false; - Dwarf_Attribute attr_mem; - dwarf_formflag(dwarf_attr(die, DW_AT_declaration, - &attr_mem), &declaration); - if (not declaration) { - // let's be curious and look deeper in the tree, - // function are not necessarily at the first level, but - // might be nested inside a namespace, structure etc. - Dwarf_Die die_mem; - Dwarf_Die* indie = find_fundie_by_pc(die, pc, &die_mem); - if (indie) { - *result = die_mem; - return result; - } - } - }; - } while (dwarf_siblingof(die, result) == 0); - return 0; - } - - template - static bool deep_first_search_by_pc(Dwarf_Die* parent_die, - Dwarf_Addr pc, CB cb) { - Dwarf_Die die_mem; - if (dwarf_child(parent_die, &die_mem) != 0) { - return false; - } - - bool branch_has_pc = false; - Dwarf_Die* die = &die_mem; - do { - bool declaration = false; - Dwarf_Attribute attr_mem; - dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration); - if (not declaration) { - // let's be curious and look deeper in the tree, function are - // not necessarily at the first level, but might be nested - // inside a namespace, structure, a function, an inlined - // function etc. - branch_has_pc = deep_first_search_by_pc(die, pc, cb); - } - if (not branch_has_pc) { - branch_has_pc = die_has_pc(die, pc); - } - if (branch_has_pc) { - cb(die); - } - } while (dwarf_siblingof(die, &die_mem) == 0); - return branch_has_pc; - } - - static const char* die_call_file(Dwarf_Die *die) { - Dwarf_Attribute attr_mem; - Dwarf_Sword file_idx = 0; - - dwarf_formsdata(dwarf_attr(die, DW_AT_call_file, &attr_mem), - &file_idx); - - if (file_idx == 0) { - return 0; - } - - Dwarf_Die die_mem; - Dwarf_Die* cudie = dwarf_diecu(die, &die_mem, 0, 0); - if (not cudie) { - return 0; - } - - Dwarf_Files* files = 0; - size_t nfiles; - dwarf_getsrcfiles(cudie, &files, &nfiles); - if (not files) { - return 0; - } - - return dwarf_filesrc(files, file_idx, 0, 0); - } - -}; -#endif // BACKWARD_HAS_DW == 1 - -template<> -class TraceResolverImpl: - public TraceResolverLinuxImpl {}; - -#endif // BACKWARD_SYSTEM_LINUX - -class TraceResolver: - public TraceResolverImpl {}; - -/*************** CODE SNIPPET ***************/ - -class SourceFile { -public: - typedef std::vector > lines_t; - - SourceFile() {} - SourceFile(const std::string& path): _file(new std::ifstream(path.c_str())) {} - bool is_open() const { return _file->is_open(); } - - lines_t& get_lines(unsigned line_start, unsigned line_count, lines_t& lines) { - using namespace std; - // This function make uses of the dumbest algo ever: - // 1) seek(0) - // 2) read lines one by one and discard until line_start - // 3) read line one by one until line_start + line_count - // - // If you are getting snippets many time from the same file, it is - // somewhat a waste of CPU, feel free to benchmark and propose a - // better solution ;) - - _file->clear(); - _file->seekg(0); - string line; - unsigned line_idx; - - for (line_idx = 1; line_idx < line_start; ++line_idx) { - getline(*_file, line); - if (not *_file) { - return lines; - } - } - - // think of it like a lambda in C++98 ;) - // but look, I will reuse it two times! - // What a good boy am I. - struct isspace { - bool operator()(char c) { - return std::isspace(c); - } - }; - - bool started = false; - for (; line_idx < line_start + line_count; ++line_idx) { - getline(*_file, line); - if (not *_file) { - return lines; - } - if (not started) { - if (std::find_if(line.begin(), line.end(), - not_isspace()) == line.end()) - continue; - started = true; - } - lines.push_back(make_pair(line_idx, line)); - } - - lines.erase( - std::find_if(lines.rbegin(), lines.rend(), - not_isempty()).base(), lines.end() - ); - return lines; - } - - lines_t get_lines(unsigned line_start, unsigned line_count) { - lines_t lines; - return get_lines(line_start, line_count, lines); - } - - // there is no find_if_not in C++98, lets do something crappy to - // workaround. - struct not_isspace { - bool operator()(char c) { - return not std::isspace(c); - } - }; - // and define this one here because C++98 is not happy with local defined - // struct passed to template functions, fuuuu. - struct not_isempty { - bool operator()(const lines_t::value_type& p) { - return not (std::find_if(p.second.begin(), p.second.end(), - not_isspace()) == p.second.end()); - } - }; - - void swap(SourceFile& b) { - _file.swap(b._file); - } - -#if defined(BACKWARD_CXX11) - SourceFile(SourceFile&& from): _file(0) { - swap(from); - } - SourceFile& operator=(SourceFile&& from) { - swap(from); return *this; - } -#else - explicit SourceFile(const SourceFile& from) { - // some sort of poor man's move semantic. - swap(const_cast(from)); - } - SourceFile& operator=(const SourceFile& from) { - // some sort of poor man's move semantic. - swap(const_cast(from)); return *this; - } -#endif - -private: - details::handle - > _file; - -#if defined(BACKWARD_CXX11) - SourceFile(const SourceFile&) = delete; - SourceFile& operator=(const SourceFile&) = delete; -#endif -}; - -class SnippetFactory { -public: - typedef SourceFile::lines_t lines_t; - - lines_t get_snippet(const std::string& filename, - unsigned line_start, unsigned context_size) { - - SourceFile& src_file = get_src_file(filename); - unsigned start = line_start - context_size / 2; - return src_file.get_lines(start, context_size); - } - - lines_t get_combined_snippet( - const std::string& filename_a, unsigned line_a, - const std::string& filename_b, unsigned line_b, - unsigned context_size) { - SourceFile& src_file_a = get_src_file(filename_a); - SourceFile& src_file_b = get_src_file(filename_b); - - lines_t lines = src_file_a.get_lines(line_a - context_size / 4, - context_size / 2); - src_file_b.get_lines(line_b - context_size / 4, context_size / 2, - lines); - return lines; - } - - lines_t get_coalesced_snippet(const std::string& filename, - unsigned line_a, unsigned line_b, unsigned context_size) { - SourceFile& src_file = get_src_file(filename); - - using std::min; using std::max; - unsigned a = min(line_a, line_b); - unsigned b = max(line_a, line_b); - - if ((b - a) < (context_size / 3)) { - return src_file.get_lines((a + b - context_size + 1) / 2, - context_size); - } - - lines_t lines = src_file.get_lines(a - context_size / 4, - context_size / 2); - src_file.get_lines(b - context_size / 4, context_size / 2, lines); - return lines; - } - - -private: - typedef details::hashtable::type src_files_t; - src_files_t _src_files; - - SourceFile& get_src_file(const std::string& filename) { - src_files_t::iterator it = _src_files.find(filename); - if (it != _src_files.end()) { - return it->second; - } - SourceFile& new_src_file = _src_files[filename]; - new_src_file = SourceFile(filename); - return new_src_file; - } -}; - -/*************** PRINTER ***************/ - -#ifdef BACKWARD_SYSTEM_LINUX - -namespace Color { - enum type { - yellow = 33, - purple = 35, - reset = 39 - }; -} // namespace Color - -class Colorize { -public: - Colorize(std::FILE* os): - _os(os), _reset(false), _istty(false) {} - - void init() { - _istty = isatty(fileno(_os)); - } - - void set_color(Color::type ccode) { - if (not _istty) return; - - // I assume that the terminal can handle basic colors. Seriously I - // don't want to deal with all the termcap shit. - fprintf(_os, "\033[%im", static_cast(ccode)); - _reset = (ccode != Color::reset); - } - - ~Colorize() { - if (_reset) { - set_color(Color::reset); - } - } - -private: - std::FILE* _os; - bool _reset; - bool _istty; -}; - -#else // ndef BACKWARD_SYSTEM_LINUX - - -namespace Color { - enum type { - yellow = 0, - purple = 0, - reset = 0 - }; -} // namespace Color - -class Colorize { -public: - Colorize(std::FILE*) {} - void init() {} - void set_color(Color::type) {} -}; - -#endif // BACKWARD_SYSTEM_LINUX - -class Printer { -public: - bool snippet; - bool color; - bool address; - bool object; - - Printer(): - snippet(true), - color(true), - address(false), - object(false) - {} - - template - FILE* print(StackTrace& st, FILE* os = stderr) { - using namespace std; - - Colorize colorize(os); - if (color) { - colorize.init(); - } - - fprintf(os, "Stack trace (most recent call last)"); - if (st.thread_id()) { - fprintf(os, " in thread %u:\n", st.thread_id()); - } else { - fprintf(os, ":\n"); - } - - _resolver.load_stacktrace(st); - for (unsigned trace_idx = st.size(); trace_idx > 0; --trace_idx) { - fprintf(os, "#%-2u", trace_idx); - bool already_indented = true; - const ResolvedTrace trace = _resolver.resolve(st[trace_idx-1]); - - if (not trace.source.filename.size() or object) { - fprintf(os, " Object \"%s\", at %p, in %s\n", - trace.object_filename.c_str(), trace.addr, - trace.object_function.c_str()); - already_indented = false; - } - - if (trace.source.filename.size()) { - for (size_t inliner_idx = trace.inliners.size(); - inliner_idx > 0; --inliner_idx) { - if (not already_indented) { - fprintf(os, " "); - } - const ResolvedTrace::SourceLoc& inliner_loc - = trace.inliners[inliner_idx-1]; - print_source_loc(os, " | ", inliner_loc); - if (snippet) { - print_snippet(os, " | ", inliner_loc, - colorize, Color::purple, 5); - } - already_indented = false; - } - - if (not already_indented) { - fprintf(os, " "); - } - print_source_loc(os, " ", trace.source, trace.addr); - if (snippet) { - print_snippet(os, " ", trace.source, - colorize, Color::yellow, 7); - } - - if (trace.locals.size()) { - print_locals(os, " ", trace.locals); - } - } - } - return os; - } -private: - TraceResolver _resolver; - SnippetFactory _snippets; - - void print_snippet(FILE* os, const char* indent, - const ResolvedTrace::SourceLoc& source_loc, - Colorize& colorize, Color::type color_code, - int context_size) - { - using namespace std; - typedef SnippetFactory::lines_t lines_t; - - lines_t lines = _snippets.get_snippet(source_loc.filename, - source_loc.line, context_size); - - for (lines_t::const_iterator it = lines.begin(); - it != lines.end(); ++it) { - if (it-> first == source_loc.line) { - colorize.set_color(color_code); - fprintf(os, "%s>", indent); - } else { - fprintf(os, "%s ", indent); - } - fprintf(os, "%4u: %s\n", it->first, it->second.c_str()); - if (it-> first == source_loc.line) { - colorize.set_color(Color::reset); - } - } - } - - void print_source_loc(FILE* os, const char* indent, - const ResolvedTrace::SourceLoc& source_loc, - void* addr=0) { - fprintf(os, "%sSource \"%s\", line %i, in %s", - indent, source_loc.filename.c_str(), (int)source_loc.line, - source_loc.function.c_str()); - - if (address and addr != 0) { - fprintf(os, " [%p]\n", addr); - } else { - fprintf(os, "\n"); - } - } - - void print_var(FILE* os, const char* base_indent, int indent, - const Variable& var) { - fprintf(os, "%s%s: ", base_indent, var.name.c_str()); - switch (var.kind) { - case Variable::VALUE: - fprintf(os, "%s\n", var.value().c_str()); - break; - case Variable::LIST: - fprintf(os, "["); - for (size_t i = 0; i < var.list().size(); ++i) { - if (i > 0) { - fprintf(os, ", %s", var.list()[i].c_str()); - } - fprintf(os, "%s", var.list()[i].c_str()); - } - fprintf(os, "]\n"); - break; - case Variable::MAP: - fprintf(os, "{\n"); - for (size_t i = 0; i < var.map().size(); ++i) { - if (i > 0) { - fprintf(os, ",\n%s", base_indent); - } - print_var(os, base_indent, indent + 2, var.map()[i]); - } - fprintf(os, "]\n"); - break; - }; - } - - void print_locals(FILE* os, const char* indent, - const std::vector& locals) { - fprintf(os, "%sLocal variables:\n", indent); - for (size_t i = 0; i < locals.size(); ++i) { - if (i > 0) { - fprintf(os, ",\n%s", indent); - } - print_var(os, indent, 0, locals[i]); - } - } -}; - -/*************** SIGNALS HANDLING ***************/ - -#ifdef BACKWARD_SYSTEM_LINUX - - -class SignalHandling { -public: - static std::vector make_default_signals() { - const int signals[] = { - // default action: Core - SIGILL, - SIGABRT, - SIGFPE, - SIGSEGV, - SIGBUS, - // I am not sure the following signals should be enabled by - // default: - // default action: Term - SIGHUP, - SIGINT, - SIGPIPE, - SIGALRM, - SIGTERM, - SIGUSR1, - SIGUSR2, - SIGPOLL, - SIGPROF, - SIGVTALRM, - SIGIO, - SIGPWR, - // default action: Core - SIGQUIT, - SIGSYS, - SIGTRAP, - SIGXCPU, - SIGXFSZ - }; - return std::vector(signals, signals + sizeof signals); - } - - SignalHandling(const std::vector& signals = make_default_signals()) : _loaded(false) { - bool success = true; - - const size_t stack_size = 1024 * 1024 * 8; - _stack_content.reset((char*)malloc(stack_size)); - if (_stack_content) { - stack_t ss; - ss.ss_sp = _stack_content.get(); - ss.ss_size = stack_size; - ss.ss_flags = 0; - if (sigaltstack(&ss, 0) < 0) { - success = false; - } - } else { - success = false; - } - - for (size_t i = 0; i < signals.size(); ++i) { - struct sigaction action; - action.sa_flags = SA_SIGINFO | SA_ONSTACK; - sigemptyset(&action.sa_mask); - action.sa_sigaction = &sig_handler; - - int r = sigaction(signals[i], &action, 0); - if (r < 0) success = false; - } - _loaded = success; - } - - bool loaded() const { return _loaded; } - -private: - details::handle _stack_content; - bool _loaded; - - static void sig_handler(int, siginfo_t* info, void* _ctx) { - ucontext_t *uctx = (ucontext_t*) _ctx; - - StackTrace st; - void* error_addr = 0; -#ifdef REG_RIP // x86_64 - error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_RIP]); -#elif defined(REG_EIP) // x86_32 - error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_EIP]); -#else -# warning ":/ sorry, ain't know no nothing none not of your architecture!" -#endif - if (error_addr) { - st.load_from(error_addr, 32); - } else { - st.load_here(32); - } - - Printer printer; - printer.address = true; - printer.print(st, stderr); - - psiginfo(info, 0); - // terminate the process immediately. - _exit(EXIT_FAILURE); - } -}; - -#endif // BACKWARD_SYSTEM_LINUX - -#ifdef BACKWARD_SYSTEM_UNKNOWN - -class SignalHandling { -public: - SignalHandling(const std::vector& = std::vector()) {} - bool init() { return false; } -}; - -#endif // BACKWARD_SYSTEM_UNKNOWN - -#if 0 -void crit_err_hdlr(int sig_num, siginfo_t * info, void * ucontext) -{ - void * array[50]; - void * caller_address; - char ** messages; - int size, i; - sig_ucontext_t * uc; - - uc = (sig_ucontext_t *)ucontext; - - /* Get the address at the time the signal was raised from the EIP (x86) */ - caller_address = (void *) uc->uc_mcontext.eip; - - fprintf(stderr, "signal %d (%s), address is %p from %p\n", - sig_num, strsignal(sig_num), info->si_addr, - (void *)caller_address); - - size = backtrace(array, 50); - - /* overwrite sigaction with caller's address */ - array[1] = caller_address; - - messages = backtrace_symbols(array, size); - - -void sig_handler(int sig, siginfo_t* info, void* _ctx) { -ucontext_t *context = (ucontext_t*) _ctx; - -psiginfo(info, "Shit hit the fan"); -exit(EXIT_FAILURE); -} - -using namespace std; - -void badass() { -cout << "baddass!" << endl; -((char*)&badass)[0] = 42; -} - -int main() { -struct sigaction action; -action.sa_flags = SA_SIGINFO; -sigemptyset(&action.sa_mask); -action.sa_sigaction = &sig_handler; -int r = sigaction(SIGSEGV, &action, 0); -if (r < 0) { err(errno, 0); } -r = sigaction(SIGILL, &action, 0); -if (r < 0) { err(errno, 0); } - -badass(); -return 0; -} - - -#endif - -// i want to get a stacktrace on: -// - abort -// - signals (segfault.. abort...) -// - exception -// - dont messup with gdb! -// - thread ID -// - helper for capturing stack trace inside exception -// propose a little magic wrapper to throw an exception adding a stacktrace, -// and propose a specific tool to get a stacktrace from an exception (if its -// available). -// - optional override __cxa_throw, then the specific magic tool could get -// the stacktrace. Might be possible to use a thread-local variable to do -// some shit. RTLD_DEEPBIND might do the tricks to override it on the fly. - -// maybe I can even get the last variables and theirs values? -// that might be possible. - -// print with code snippet -// print traceback demangled -// detect color stuff -// register all signals -// -// Seperate stacktrace (load and co function) -// than object extracting informations about a stack trace. - -// also public a simple function to print a stacktrace. - -// backtrace::StackTrace st; -// st.snapshot(); -// print(st); -// cout << st; - -} // namespace backward - -#endif /* H_GUARD */ diff --git a/src/dionysus/dionysus/chain.h b/src/dionysus/chain.h similarity index 100% rename from src/dionysus/dionysus/chain.h rename to src/dionysus/chain.h diff --git a/src/dionysus/dionysus/chain.hpp b/src/dionysus/chain.hpp similarity index 100% rename from src/dionysus/dionysus/chain.hpp rename to src/dionysus/chain.hpp diff --git a/src/dionysus/dionysus/clearing-reduction.h b/src/dionysus/clearing-reduction.h similarity index 100% rename from src/dionysus/dionysus/clearing-reduction.h rename to src/dionysus/clearing-reduction.h diff --git a/src/dionysus/dionysus/clearing-reduction.hpp b/src/dionysus/clearing-reduction.hpp similarity index 100% rename from src/dionysus/dionysus/clearing-reduction.hpp rename to src/dionysus/clearing-reduction.hpp diff --git a/src/dionysus/dionysus/cohomology-persistence.h b/src/dionysus/cohomology-persistence.h similarity index 100% rename from src/dionysus/dionysus/cohomology-persistence.h rename to src/dionysus/cohomology-persistence.h diff --git a/src/dionysus/dionysus/cohomology-persistence.hpp b/src/dionysus/cohomology-persistence.hpp similarity index 100% rename from src/dionysus/dionysus/cohomology-persistence.hpp rename to src/dionysus/cohomology-persistence.hpp diff --git a/src/dionysus/dionysus/diagram.h b/src/dionysus/diagram.h similarity index 100% rename from src/dionysus/dionysus/diagram.h rename to src/dionysus/diagram.h diff --git a/src/dionysus/dionysus/distances.h b/src/dionysus/distances.h similarity index 100% rename from src/dionysus/dionysus/distances.h rename to src/dionysus/distances.h diff --git a/src/dionysus/dionysus/distances.hpp b/src/dionysus/distances.hpp similarity index 100% rename from src/dionysus/dionysus/distances.hpp rename to src/dionysus/distances.hpp diff --git a/src/dionysus/dionysus/dlog/progress.h b/src/dionysus/dlog/progress.h similarity index 100% rename from src/dionysus/dionysus/dlog/progress.h rename to src/dionysus/dlog/progress.h diff --git a/src/dionysus/dionysus/fields/q.h b/src/dionysus/fields/q.h similarity index 100% rename from src/dionysus/dionysus/fields/q.h rename to src/dionysus/fields/q.h diff --git a/src/dionysus/dionysus/fields/z2.h b/src/dionysus/fields/z2.h similarity index 100% rename from src/dionysus/dionysus/fields/z2.h rename to src/dionysus/fields/z2.h diff --git a/src/dionysus/dionysus/fields/zp.h b/src/dionysus/fields/zp.h similarity index 100% rename from src/dionysus/dionysus/fields/zp.h rename to src/dionysus/fields/zp.h diff --git a/src/dionysus/dionysus/filtration.h b/src/dionysus/filtration.h similarity index 100% rename from src/dionysus/dionysus/filtration.h rename to src/dionysus/filtration.h diff --git a/src/dionysus/dionysus/omni-field-persistence.h b/src/dionysus/omni-field-persistence.h similarity index 100% rename from src/dionysus/dionysus/omni-field-persistence.h rename to src/dionysus/omni-field-persistence.h diff --git a/src/dionysus/dionysus/omni-field-persistence.hpp b/src/dionysus/omni-field-persistence.hpp similarity index 100% rename from src/dionysus/dionysus/omni-field-persistence.hpp rename to src/dionysus/omni-field-persistence.hpp diff --git a/src/dionysus/dionysus/ordinary-persistence.h b/src/dionysus/ordinary-persistence.h similarity index 100% rename from src/dionysus/dionysus/ordinary-persistence.h rename to src/dionysus/ordinary-persistence.h diff --git a/src/dionysus/dionysus/pair-recorder.h b/src/dionysus/pair-recorder.h similarity index 100% rename from src/dionysus/dionysus/pair-recorder.h rename to src/dionysus/pair-recorder.h diff --git a/src/dionysus/dionysus/reduced-matrix.h b/src/dionysus/reduced-matrix.h similarity index 100% rename from src/dionysus/dionysus/reduced-matrix.h rename to src/dionysus/reduced-matrix.h diff --git a/src/dionysus/dionysus/reduced-matrix.hpp b/src/dionysus/reduced-matrix.hpp similarity index 100% rename from src/dionysus/dionysus/reduced-matrix.hpp rename to src/dionysus/reduced-matrix.hpp diff --git a/src/dionysus/dionysus/reduction.h b/src/dionysus/reduction.h similarity index 100% rename from src/dionysus/dionysus/reduction.h rename to src/dionysus/reduction.h diff --git a/src/dionysus/dionysus/relative-homology-zigzag.h b/src/dionysus/relative-homology-zigzag.h similarity index 100% rename from src/dionysus/dionysus/relative-homology-zigzag.h rename to src/dionysus/relative-homology-zigzag.h diff --git a/src/dionysus/dionysus/relative-homology-zigzag.hpp b/src/dionysus/relative-homology-zigzag.hpp similarity index 100% rename from src/dionysus/dionysus/relative-homology-zigzag.hpp rename to src/dionysus/relative-homology-zigzag.hpp diff --git a/src/dionysus/dionysus/rips.h b/src/dionysus/rips.h similarity index 100% rename from src/dionysus/dionysus/rips.h rename to src/dionysus/rips.h diff --git a/src/dionysus/dionysus/rips.hpp b/src/dionysus/rips.hpp similarity index 100% rename from src/dionysus/dionysus/rips.hpp rename to src/dionysus/rips.hpp diff --git a/src/dionysus/dionysus/row-reduction.h b/src/dionysus/row-reduction.h similarity index 100% rename from src/dionysus/dionysus/row-reduction.h rename to src/dionysus/row-reduction.h diff --git a/src/dionysus/dionysus/row-reduction.hpp b/src/dionysus/row-reduction.hpp similarity index 100% rename from src/dionysus/dionysus/row-reduction.hpp rename to src/dionysus/row-reduction.hpp diff --git a/src/dionysus/dionysus/simplex.h b/src/dionysus/simplex.h similarity index 100% rename from src/dionysus/dionysus/simplex.h rename to src/dionysus/simplex.h diff --git a/src/dionysus/dionysus/sparse-row-matrix.h b/src/dionysus/sparse-row-matrix.h similarity index 100% rename from src/dionysus/dionysus/sparse-row-matrix.h rename to src/dionysus/sparse-row-matrix.h diff --git a/src/dionysus/dionysus/sparse-row-matrix.hpp b/src/dionysus/sparse-row-matrix.hpp similarity index 100% rename from src/dionysus/dionysus/sparse-row-matrix.hpp rename to src/dionysus/sparse-row-matrix.hpp diff --git a/src/dionysus/dionysus/standard-reduction.h b/src/dionysus/standard-reduction.h similarity index 100% rename from src/dionysus/dionysus/standard-reduction.h rename to src/dionysus/standard-reduction.h diff --git a/src/dionysus/dionysus/standard-reduction.hpp b/src/dionysus/standard-reduction.hpp similarity index 100% rename from src/dionysus/dionysus/standard-reduction.hpp rename to src/dionysus/standard-reduction.hpp diff --git a/src/dionysus/dionysus/trails-chains.h b/src/dionysus/trails-chains.h similarity index 100% rename from src/dionysus/dionysus/trails-chains.h rename to src/dionysus/trails-chains.h diff --git a/src/dionysus/dionysus/zigzag-persistence.h b/src/dionysus/zigzag-persistence.h similarity index 100% rename from src/dionysus/dionysus/zigzag-persistence.h rename to src/dionysus/zigzag-persistence.h diff --git a/src/dionysus/dionysus/zigzag-persistence.hpp b/src/dionysus/zigzag-persistence.hpp similarity index 100% rename from src/dionysus/dionysus/zigzag-persistence.hpp rename to src/dionysus/zigzag-persistence.hpp diff --git a/src/tdautils/dionysusUtils.h b/src/tdautils/dionysusUtils.h index bc74f56..925b85e 100644 --- a/src/tdautils/dionysusUtils.h +++ b/src/tdautils/dionysusUtils.h @@ -1,10 +1,9 @@ #ifndef __DIONYSUSUTILS_H__ #define __DIONYSUSUTILS_H__ -#include "../dionysus/dionysus/simplex.h" -#include "../dionysus/dionysus/rips.h" -#include "../dionysus/dionysus/filtration.h" -#include "../dionysus/dionysus/standard-reduction.h" +#include +#include +#include // swapping simplex //#include From 9a31d40f92e970a394382a213ba094e122928fb8 Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 25 Jul 2018 11:01:25 -0600 Subject: [PATCH 06/29] swapped distances.h --- src/tdautils/gridUtils.h | 17 +++++++++++------ src/tdautils/ripsL2.h | 33 +++++++++++++++++++++++++++------ src/tdautils/ripsL2backup.h | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 src/tdautils/ripsL2backup.h diff --git a/src/tdautils/gridUtils.h b/src/tdautils/gridUtils.h index 7d12b09..1a2219c 100644 --- a/src/tdautils/gridUtils.h +++ b/src/tdautils/gridUtils.h @@ -3,13 +3,18 @@ #include -#include -#include -#include -#include -#include +//#include +//#include +//#include +//#include +//#include #include +#include +#include +#include +#include + #include #include #include @@ -531,4 +536,4 @@ void simplicesFromGridBarycenter( -# endif // __GRIDUTILS_H__ \ No newline at end of file +# endif // __GRIDUTILS_H__ diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index 2acc8ca..d3d044b 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -1,24 +1,45 @@ #include -#include +//#include #include #include #include -#include -#include +//#include +//#include #include // for BackInsertFunctor #include + +//dionysus2 +//#include +#include +//#include +#include +//#include #include -typedef PairwiseDistances PairDistances; +namespace d = dionysus; + + +//L2 Struct is inside distances.h +// Feels very janky + +// typedef std::vector Point; +// typedef std::vector PointContainer; + +typedef d::PairwiseDistances>, d::L2Distance>> PairDistances; typedef PairDistances::DistanceType DistanceType; typedef PairDistances::IndexType VertexR; typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; typedef Generator::Simplex SmplxR; typedef Filtration FltrR; + +// Comment this test typedef StaticPersistence<> PersistenceR; -//typedef DynamicPersistenceChains<> PersistenceR; -typedef PersistenceDiagram<> PDgmR; +// relabel +//typedef OrdinaryPersistence<> PersistenceR; +//typedef DynamicPersistenceChains<> PersistenceR; +typedef PersistenceDiagram<> PDgmR; + diff --git a/src/tdautils/ripsL2backup.h b/src/tdautils/ripsL2backup.h new file mode 100644 index 0000000..172a900 --- /dev/null +++ b/src/tdautils/ripsL2backup.h @@ -0,0 +1,32 @@ +#include +//#include +#include +#include +#include + +#include +#include +#include // for BackInsertFunctor +#include + +//dionysus2 +//#include +#include +//#include +//#include +//#include + +#include + + +typedef PairwiseDistances PairDistances; +typedef PairDistances::DistanceType DistanceType; +typedef PairDistances::IndexType VertexR; +typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; +typedef Generator::Simplex SmplxR; +typedef Filtration FltrR; +typedef StaticPersistence<> PersistenceR; +//typedef DynamicPersistenceChains<> PersistenceR; +typedef PersistenceDiagram<> PDgmR; + + From 0a9466f3479b237d7fcde0236de85f36366b405f Mon Sep 17 00:00:00 2001 From: thomashli Date: Mon, 30 Jul 2018 16:35:08 -0600 Subject: [PATCH 07/29] Arbit works to --- src/tdautils/ripsArbit.h | 13 +++++++++---- src/tdautils/ripsL2.h | 7 ++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/tdautils/ripsArbit.h b/src/tdautils/ripsArbit.h index 9e87f52..47d0ec5 100644 --- a/src/tdautils/ripsArbit.h +++ b/src/tdautils/ripsArbit.h @@ -1,19 +1,24 @@ #include -#include +//#include #include #include #include -#include +//#include #include -#include +//#include #include // for BackInsertFunctor #include +//dionysus2 +#include +#include + #include +namespace d = dionysus; -typedef PairwiseDistances PairDistancesA; +typedef d::PairwiseDistances>, ArbitDistance> PairDistancesA; typedef PairDistancesA::DistanceType DistanceTypeA; typedef PairDistancesA::IndexType VertexRA; typedef Rips< PairDistancesA, Simplex< VertexRA, double > > GeneratorA; diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index d3d044b..53ac103 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -22,14 +22,19 @@ namespace d = dionysus; //L2 Struct is inside distances.h -// Feels very janky +// not sure if I want to typedef here // typedef std::vector Point; // typedef std::vector PointContainer; +// Swapped out typedef d::PairwiseDistances>, d::L2Distance>> PairDistances; + +//Next two lines are fine typedef PairDistances::DistanceType DistanceType; typedef PairDistances::IndexType VertexR; + + typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; typedef Generator::Simplex SmplxR; typedef Filtration FltrR; From 1401ed7ac9b0814c9645c5b36151e158d29928db Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:33:56 -0700 Subject: [PATCH 08/29] Made FiltrationDiagDionysus2 --- src/tdautils/diagramDS.h | 61 +++++++++++++ src/tdautils/dionysus2Utils.h | 129 +++++++++++++++++++++++++++ tests/testthat/test_FiltrationDiag.R | 15 ++++ 3 files changed, 205 insertions(+) create mode 100644 src/tdautils/diagramDS.h create mode 100644 src/tdautils/dionysus2Utils.h create mode 100644 tests/testthat/test_FiltrationDiag.R diff --git a/src/tdautils/diagramDS.h b/src/tdautils/diagramDS.h new file mode 100644 index 0000000..84333b1 --- /dev/null +++ b/src/tdautils/diagramDS.h @@ -0,0 +1,61 @@ +#ifndef DIAGRAM_DS_H +#define DIAGRAM_DS_H +#endif + +#include +#include + +namespace d = dionysus; + +namespace diagramDS +{ + +template +class DiagramDS +{ + public: + using Value = Value_; + using Data = Data_; + using Diagrams = std::vector>; + + template + DiagramDS(const ReducedMatrix& m, const Filtration& f, const GetValue get_value, const GetData get_data) + { + //auto get_value = [&](const Simplex& s) -> float { return filtration.index(s); }; + //auto get_data = [](Persistence::Index i) { return i; }; + for (typename ReducedMatrix::Index i = 0; i < m.size(); ++i) + { + if (m.skip(i)) + continue; + + auto& s = f[i]; + auto d = s.dimension(); + while (d + 1 > diagrams.size()) + diagrams.emplace_back(); + + auto pair = m.pair(i); + if (pair == m.unpaired()) + { + auto birth = get_value(s); + using Value = decltype(birth); + Value death = std::numeric_limits::infinity(); + diagrams[d].emplace_back(birth, death, get_data(i)); + } else if (pair > i) // positive + { + auto birth = get_value(s); + auto death = get_value(f[pair]); + + if (birth != death) // skip diagonal + diagrams[d].emplace_back(birth, death, get_data(i)); + } // else negative: do nothing + } + } + + Diagrams getDiagrams() {return diagrams;} + private: + Diagrams diagrams; +}; + +} + + diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h new file mode 100644 index 0000000..1670e1a --- /dev/null +++ b/src/tdautils/dionysus2Utils.h @@ -0,0 +1,129 @@ +#ifndef __DIONYSUS2UTILS_H__ +#define __DIONYSUS2UTILS_H__ + +#include +#include +#include +#include +#include +#include +#include +#include "diagramDS.h" +#include + +namespace d = dionysus; + + + +/* + * I'm going to assume that we simply pass in the persDgm, persLoc, and persCycle as they are above. + * I should ask Dave tomorrow what persDgm, persLoc, and persCycle are. + * Persistence Locations: for each diagram, there is a collection of points, and the points are "steps" + * Perstistence Cycle: Birth and Death? + * + * Make example to test it + * + * ask how to get the vignette + * Figure out what the Filtration that gets passed in is. + * Want to pass in Ordinary-Persistence with a Z2Field first go + * + * ask dave what initLocations and initDiagrams do + * + * ask dave what format the cmplx is passed in as + * Figure out what format Ordinary Persistence saves the persistence as + * + * Figure out what format TDA filtration is in typecast utils + * What is the format for TDA filtration? + * + * Figure out what is in FiltrationDiagDionysus + * Parameter probably comes from gridUtils.h defining Persistence as static-persistence + * Probably want to rename persistence2 + * Figure out how to convert Filtration.h in D1 to D2 so the rest of the methods work. + * + */ + +//Helper function for filling in persDgm +//template< typename Diagrams, typename iterator, typename Evaluator, typename SimplexMap > +//inline void initDiagrams; +//Helper function for filling in persLoc and persCycle +// inline void initLocations; + +// FiltrationDiag in Dionysus2 +/** \brief Construct the persistence diagram from the filtration using library +* Dionysus. +* +* @param[out] void Void +* @param[in] filtration The input filtration +* @param[in] maxdimension Max dimension of the homological features to be +* computed +* @param[in] location Are location of birth point, death point, and +* representative cycles returned? +* @param[in] printProgress Is progress printed? +* @param[in] persDgm Memory space for the resulting persistence +* diagram +* @param[in] persLoc Memory space for the resulting birth points and +* death points +* @param[in] persCycle Memory space for the resulting representative +* cycles +* @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the +* diagram. Diagram must point to enough memory space for +* 3*max_num_pairs double. If there is not enough pairs in the diagram, +* write nothing after. +*/ + +template +void FiltrationDiagDionysus2( + const Filtration &filtration, + const int maxdimension, + const bool location, + const bool printProgress, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + //Assume that Persistence that is passed in is Persistence2 + //Calculate Persistence + + d::Z2Field k; + //Persistence persistence(k); + //StandardReduction2 reduce(persistence); + d::RowReduction reduce(k); + // We know that the function breaks when this line is called. + reduce(filtration); + + typedef decltype(reduce.persistence().pair(0)) Index; + typedef float Value; + //persistence is reduced. + Index _ = 0; + // move Persistence into persDgm + //auto dgms = d::init_diagrams(reduce->persistence(), filtration, [&](const Smplx2& s) -> float { return filtration.index(s); }, [](typename Persistence::Index i) { return i; }); + + diagramDS::DiagramDS dgms( + reduce.persistence(), + filtration, + [&](const Smplx2& s) -> float { return filtration.index(s);}, + [](typename Persistence::Index i) { return i; } + ); + //emulate initDiagrams function from dionysusUtils + //will put into a function later + persDgm.resize(dgms.getDiagrams().size()); + for (auto &dgm : dgms.getDiagrams()) + { + for (auto &pt : dgm) + { + std::vector pt_; + if (pt.death() == std::numeric_limits::infinity()) { + pt_ = {filtration[pt.birth()].data(),pt.death()}; + } else { + pt_ = {filtration[pt.birth()].data(),filtration[pt.death()].data()}; + } + persDgm[_].push_back(pt_); + } + _++; + } + persDgm.resize(maxdimension + 1); + +} + +#endif __DIONYSUS2UTILS_H__ diff --git a/tests/testthat/test_FiltrationDiag.R b/tests/testthat/test_FiltrationDiag.R new file mode 100644 index 0000000..36acb7f --- /dev/null +++ b/tests/testthat/test_FiltrationDiag.R @@ -0,0 +1,15 @@ +context("FiltrationDiag") + +test_that("FiltrationDiag works with dionysus2" , { + X <- matrix(c(0,0,100,100,0,102,0,101),nrow=4) + Fltrips = ripsFiltration(X,maxdimension = 1, maxscale = 120, library = "Dionysus") + DiagRips = filtrationDiag(Fltrips, maxdimension = 0, library = "Dionysus") + DiagRips2 = filtrationDiag(Fltrips, maxdimension = 0, library = "D2", location = FALSE) + for (i in 1:nrow(DiagRips)) { + for (j in 1:ncol(DiagRips)) { + expect_equal(DiagRips[i,j],DiagRips2[i,j]) + } + } +}) + + From fe4de549cf71a15442f9a88861570a1c87543b27 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:42:52 -0700 Subject: [PATCH 09/29] prototype FiltrationDiag working --- src/tdautils/dionysus2Utils.h | 61 ++++++----------------------------- 1 file changed, 10 insertions(+), 51 deletions(-) diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h index 1670e1a..674c4ad 100644 --- a/src/tdautils/dionysus2Utils.h +++ b/src/tdautils/dionysus2Utils.h @@ -13,41 +13,6 @@ namespace d = dionysus; - - -/* - * I'm going to assume that we simply pass in the persDgm, persLoc, and persCycle as they are above. - * I should ask Dave tomorrow what persDgm, persLoc, and persCycle are. - * Persistence Locations: for each diagram, there is a collection of points, and the points are "steps" - * Perstistence Cycle: Birth and Death? - * - * Make example to test it - * - * ask how to get the vignette - * Figure out what the Filtration that gets passed in is. - * Want to pass in Ordinary-Persistence with a Z2Field first go - * - * ask dave what initLocations and initDiagrams do - * - * ask dave what format the cmplx is passed in as - * Figure out what format Ordinary Persistence saves the persistence as - * - * Figure out what format TDA filtration is in typecast utils - * What is the format for TDA filtration? - * - * Figure out what is in FiltrationDiagDionysus - * Parameter probably comes from gridUtils.h defining Persistence as static-persistence - * Probably want to rename persistence2 - * Figure out how to convert Filtration.h in D1 to D2 so the rest of the methods work. - * - */ - -//Helper function for filling in persDgm -//template< typename Diagrams, typename iterator, typename Evaluator, typename SimplexMap > -//inline void initDiagrams; -//Helper function for filling in persLoc and persCycle -// inline void initLocations; - // FiltrationDiag in Dionysus2 /** \brief Construct the persistence diagram from the filtration using library * Dionysus. @@ -82,32 +47,24 @@ void FiltrationDiagDionysus2( std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle ) { - //Assume that Persistence that is passed in is Persistence2 - //Calculate Persistence - + // Assume that Persistence that is passed in is Persistence2, Filtration is Fltr2 + // Create and Calculate Persistence d::Z2Field k; - //Persistence persistence(k); - //StandardReduction2 reduce(persistence); d::RowReduction reduce(k); - // We know that the function breaks when this line is called. reduce(filtration); typedef decltype(reduce.persistence().pair(0)) Index; typedef float Value; - //persistence is reduced. - Index _ = 0; - // move Persistence into persDgm - //auto dgms = d::init_diagrams(reduce->persistence(), filtration, [&](const Smplx2& s) -> float { return filtration.index(s); }, [](typename Persistence::Index i) { return i; }); - - diagramDS::DiagramDS dgms( + // Putting persistence into diagrams data structure + diagramDS::DiagramDS dgms( reduce.persistence(), filtration, [&](const Smplx2& s) -> float { return filtration.index(s);}, [](typename Persistence::Index i) { return i; } ); - //emulate initDiagrams function from dionysusUtils - //will put into a function later + // Fill in Diagram persDgm.resize(dgms.getDiagrams().size()); + Index _ = 0; for (auto &dgm : dgms.getDiagrams()) { for (auto &pt : dgm) @@ -122,8 +79,10 @@ void FiltrationDiagDionysus2( } _++; } - persDgm.resize(maxdimension + 1); - + // Capping at maxdimension + if (persDgm.size() > maxdimension) { + persDgm.resize(maxdimension + 1); + } } #endif __DIONYSUS2UTILS_H__ From b8af11e9207d0bcc55df128b0d3b325a5e2ac464 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:44:09 -0700 Subject: [PATCH 10/29] checkout master? --- src/tdautils/gridUtils.h | 13 +++++++++++-- src/tdautils/ripsL2.h | 27 +++++++++++---------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/tdautils/gridUtils.h b/src/tdautils/gridUtils.h index 1a2219c..65bb5b7 100644 --- a/src/tdautils/gridUtils.h +++ b/src/tdautils/gridUtils.h @@ -12,7 +12,9 @@ #include #include -#include +//#include +#include +#include #include #include @@ -42,7 +44,14 @@ typedef OffsetBeginMap FiltrationPersistenceMap; - +//dionysus2 +//needs changing +typedef d::Simplex Smplx2; +typedef d::Filtration Fltr2; +typedef d::Simplex<> Simplex2; +typedef d::Filtration Filtration2; +typedef d::ReducedMatrix Persistence2; +typedef d::StandardReduction StandardReduction2; // add a single edge to the filtration template< typename VectorList > diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index 53ac103..81e2d44 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -10,10 +10,12 @@ #include //dionysus2 +#include +//#include +//#include //#include -#include //#include -#include +//#include //#include #include @@ -27,24 +29,17 @@ namespace d = dionysus; // typedef std::vector Point; // typedef std::vector PointContainer; -// Swapped out typedef d::PairwiseDistances>, d::L2Distance>> PairDistances; - -//Next two lines are fine typedef PairDistances::DistanceType DistanceType; typedef PairDistances::IndexType VertexR; - - typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; typedef Generator::Simplex SmplxR; -typedef Filtration FltrR; - -// Comment this test -typedef StaticPersistence<> PersistenceR; - -// relabel -//typedef OrdinaryPersistence<> PersistenceR; - +typedef Filtration FltrR; +//typedef StaticPersistence<> PersistenceR; +//typedef d::Simplex<> Simplex2; +//typedef d::Filtration Filtration2; +//typedef d::OrdinaryPersistence Persistence2; +//typedef d::StandardReduction StandardReduction2; //typedef DynamicPersistenceChains<> PersistenceR; -typedef PersistenceDiagram<> PDgmR; +//typedef PersistenceDiagram<> PDgmR; From f611acdc2ce5d46861336a1dd3912eab61f80cf0 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:48:44 -0700 Subject: [PATCH 11/29] preparing to rebase --- R/RcppExports.R | 67 +++++++++++ src/RcppExports.cpp | 264 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 R/RcppExports.R create mode 100644 src/RcppExports.cpp diff --git a/R/RcppExports.R b/R/RcppExports.R new file mode 100644 index 0000000..ace733d --- /dev/null +++ b/R/RcppExports.R @@ -0,0 +1,67 @@ +# Generated by using Rcpp::compileAttributes() -> do not edit by hand +# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 + +GridFiltration <- function(FUNvalues, gridDim, maxdimension, decomposition, printProgress) { + .Call(`_TDA_GridFiltration`, FUNvalues, gridDim, maxdimension, decomposition, printProgress) +} + +GridDiag <- function(FUNvalues, gridDim, maxdimension, decomposition, library, location, printProgress) { + .Call(`_TDA_GridDiag`, FUNvalues, gridDim, maxdimension, decomposition, library, location, printProgress) +} + +Bottleneck <- function(Diag1, Diag2) { + .Call(`_TDA_Bottleneck`, Diag1, Diag2) +} + +Wasserstein <- function(Diag1, Diag2, p) { + .Call(`_TDA_Wasserstein`, Diag1, Diag2, p) +} + +Kde <- function(X, Grid, h, kertype, weight, printProgress) { + .Call(`_TDA_Kde`, X, Grid, h, kertype, weight, printProgress) +} + +KdeDist <- function(X, Grid, h, weight, printProgress) { + .Call(`_TDA_KdeDist`, X, Grid, h, weight, printProgress) +} + +Dtm <- function(knnDistance, weightBound, r) { + .Call(`_TDA_Dtm`, knnDistance, weightBound, r) +} + +DtmWeight <- function(knnDistance, weightBound, r, knnIndex, weight) { + .Call(`_TDA_DtmWeight`, knnDistance, weightBound, r, knnIndex, weight) +} + +FiltrationDiag <- function(filtration, maxdimension, library, location, printProgress) { + .Call(`_TDA_FiltrationDiag`, filtration, maxdimension, library, location, printProgress) +} + +FunFiltration <- function(FUNvalues, cmplx) { + .Call(`_TDA_FunFiltration`, FUNvalues, cmplx) +} + +RipsFiltration <- function(X, maxdimension, maxscale, dist, library, printProgress) { + .Call(`_TDA_RipsFiltration`, X, maxdimension, maxscale, dist, library, printProgress) +} + +RipsDiag <- function(X, maxdimension, maxscale, dist, libraryFiltration, libraryDiag, location, printProgress) { + .Call(`_TDA_RipsDiag`, X, maxdimension, maxscale, dist, libraryFiltration, libraryDiag, location, printProgress) +} + +AlphaShapeFiltration <- function(X, printProgress) { + .Call(`_TDA_AlphaShapeFiltration`, X, printProgress) +} + +AlphaShapeDiag <- function(X, maxdimension, libraryDiag, location, printProgress) { + .Call(`_TDA_AlphaShapeDiag`, X, maxdimension, libraryDiag, location, printProgress) +} + +AlphaComplexFiltration <- function(X, printProgress) { + .Call(`_TDA_AlphaComplexFiltration`, X, printProgress) +} + +AlphaComplexDiag <- function(X, maxdimension, libraryDiag, location, printProgress) { + .Call(`_TDA_AlphaComplexDiag`, X, maxdimension, libraryDiag, location, printProgress) +} + diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp new file mode 100644 index 0000000..f772558 --- /dev/null +++ b/src/RcppExports.cpp @@ -0,0 +1,264 @@ +// Generated by using Rcpp::compileAttributes() -> do not edit by hand +// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 + +#include +#include + +using namespace Rcpp; + +// GridFiltration +Rcpp::List GridFiltration(const Rcpp::NumericVector& FUNvalues, const Rcpp::IntegerVector& gridDim, const int maxdimension, const std::string& decomposition, const bool printProgress); +RcppExport SEXP _TDA_GridFiltration(SEXP FUNvaluesSEXP, SEXP gridDimSEXP, SEXP maxdimensionSEXP, SEXP decompositionSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type FUNvalues(FUNvaluesSEXP); + Rcpp::traits::input_parameter< const Rcpp::IntegerVector& >::type gridDim(gridDimSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type decomposition(decompositionSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(GridFiltration(FUNvalues, gridDim, maxdimension, decomposition, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// GridDiag +Rcpp::List GridDiag(const Rcpp::NumericVector& FUNvalues, const Rcpp::IntegerVector& gridDim, const int maxdimension, const std::string& decomposition, const std::string& library, const bool location, const bool printProgress); +RcppExport SEXP _TDA_GridDiag(SEXP FUNvaluesSEXP, SEXP gridDimSEXP, SEXP maxdimensionSEXP, SEXP decompositionSEXP, SEXP librarySEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type FUNvalues(FUNvaluesSEXP); + Rcpp::traits::input_parameter< const Rcpp::IntegerVector& >::type gridDim(gridDimSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type decomposition(decompositionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type library(librarySEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(GridDiag(FUNvalues, gridDim, maxdimension, decomposition, library, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// Bottleneck +double Bottleneck(const Rcpp::NumericMatrix& Diag1, const Rcpp::NumericMatrix& Diag2); +RcppExport SEXP _TDA_Bottleneck(SEXP Diag1SEXP, SEXP Diag2SEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag1(Diag1SEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag2(Diag2SEXP); + rcpp_result_gen = Rcpp::wrap(Bottleneck(Diag1, Diag2)); + return rcpp_result_gen; +END_RCPP +} +// Wasserstein +double Wasserstein(const Rcpp::NumericMatrix& Diag1, const Rcpp::NumericMatrix& Diag2, const int p); +RcppExport SEXP _TDA_Wasserstein(SEXP Diag1SEXP, SEXP Diag2SEXP, SEXP pSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag1(Diag1SEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag2(Diag2SEXP); + Rcpp::traits::input_parameter< const int >::type p(pSEXP); + rcpp_result_gen = Rcpp::wrap(Wasserstein(Diag1, Diag2, p)); + return rcpp_result_gen; +END_RCPP +} +// Kde +Rcpp::NumericVector Kde(const Rcpp::NumericMatrix& X, const Rcpp::NumericMatrix& Grid, const double h, const std::string& kertype, const Rcpp::NumericVector& weight, const bool printProgress); +RcppExport SEXP _TDA_Kde(SEXP XSEXP, SEXP GridSEXP, SEXP hSEXP, SEXP kertypeSEXP, SEXP weightSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Grid(GridSEXP); + Rcpp::traits::input_parameter< const double >::type h(hSEXP); + Rcpp::traits::input_parameter< const std::string& >::type kertype(kertypeSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type weight(weightSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(Kde(X, Grid, h, kertype, weight, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// KdeDist +Rcpp::NumericVector KdeDist(const Rcpp::NumericMatrix& X, const Rcpp::NumericMatrix& Grid, const double h, const Rcpp::NumericVector& weight, const bool printProgress); +RcppExport SEXP _TDA_KdeDist(SEXP XSEXP, SEXP GridSEXP, SEXP hSEXP, SEXP weightSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Grid(GridSEXP); + Rcpp::traits::input_parameter< const double >::type h(hSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type weight(weightSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(KdeDist(X, Grid, h, weight, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// Dtm +Rcpp::NumericVector Dtm(const Rcpp::NumericMatrix& knnDistance, const double weightBound, const double r); +RcppExport SEXP _TDA_Dtm(SEXP knnDistanceSEXP, SEXP weightBoundSEXP, SEXP rSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type knnDistance(knnDistanceSEXP); + Rcpp::traits::input_parameter< const double >::type weightBound(weightBoundSEXP); + Rcpp::traits::input_parameter< const double >::type r(rSEXP); + rcpp_result_gen = Rcpp::wrap(Dtm(knnDistance, weightBound, r)); + return rcpp_result_gen; +END_RCPP +} +// DtmWeight +Rcpp::NumericVector DtmWeight(const Rcpp::NumericMatrix& knnDistance, const double weightBound, const double r, const Rcpp::NumericMatrix& knnIndex, const Rcpp::NumericVector& weight); +RcppExport SEXP _TDA_DtmWeight(SEXP knnDistanceSEXP, SEXP weightBoundSEXP, SEXP rSEXP, SEXP knnIndexSEXP, SEXP weightSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type knnDistance(knnDistanceSEXP); + Rcpp::traits::input_parameter< const double >::type weightBound(weightBoundSEXP); + Rcpp::traits::input_parameter< const double >::type r(rSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type knnIndex(knnIndexSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type weight(weightSEXP); + rcpp_result_gen = Rcpp::wrap(DtmWeight(knnDistance, weightBound, r, knnIndex, weight)); + return rcpp_result_gen; +END_RCPP +} +// FiltrationDiag +Rcpp::List FiltrationDiag(const Rcpp::List& filtration, const int maxdimension, const std::string& library, const bool location, const bool printProgress); +RcppExport SEXP _TDA_FiltrationDiag(SEXP filtrationSEXP, SEXP maxdimensionSEXP, SEXP librarySEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type filtration(filtrationSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type library(librarySEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(FiltrationDiag(filtration, maxdimension, library, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// FunFiltration +Rcpp::List FunFiltration(const Rcpp::NumericVector& FUNvalues, const Rcpp::List& cmplx); +RcppExport SEXP _TDA_FunFiltration(SEXP FUNvaluesSEXP, SEXP cmplxSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type FUNvalues(FUNvaluesSEXP); + Rcpp::traits::input_parameter< const Rcpp::List& >::type cmplx(cmplxSEXP); + rcpp_result_gen = Rcpp::wrap(FunFiltration(FUNvalues, cmplx)); + return rcpp_result_gen; +END_RCPP +} +// RipsFiltration +Rcpp::List RipsFiltration(const Rcpp::NumericMatrix& X, const int maxdimension, const double maxscale, const std::string& dist, const std::string& library, const bool printProgress); +RcppExport SEXP _TDA_RipsFiltration(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP maxscaleSEXP, SEXP distSEXP, SEXP librarySEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const double >::type maxscale(maxscaleSEXP); + Rcpp::traits::input_parameter< const std::string& >::type dist(distSEXP); + Rcpp::traits::input_parameter< const std::string& >::type library(librarySEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(RipsFiltration(X, maxdimension, maxscale, dist, library, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// RipsDiag +Rcpp::List RipsDiag(const Rcpp::NumericMatrix& X, const int maxdimension, const double maxscale, const std::string& dist, const std::string& libraryFiltration, const std::string& libraryDiag, const bool location, const bool printProgress); +RcppExport SEXP _TDA_RipsDiag(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP maxscaleSEXP, SEXP distSEXP, SEXP libraryFiltrationSEXP, SEXP libraryDiagSEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const double >::type maxscale(maxscaleSEXP); + Rcpp::traits::input_parameter< const std::string& >::type dist(distSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryFiltration(libraryFiltrationSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryDiag(libraryDiagSEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(RipsDiag(X, maxdimension, maxscale, dist, libraryFiltration, libraryDiag, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaShapeFiltration +Rcpp::List AlphaShapeFiltration(const Rcpp::NumericMatrix& X, const bool printProgress); +RcppExport SEXP _TDA_AlphaShapeFiltration(SEXP XSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaShapeFiltration(X, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaShapeDiag +Rcpp::List AlphaShapeDiag(const Rcpp::NumericMatrix& X, const int maxdimension, const std::string& libraryDiag, const bool location, const bool printProgress); +RcppExport SEXP _TDA_AlphaShapeDiag(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP libraryDiagSEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryDiag(libraryDiagSEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaShapeDiag(X, maxdimension, libraryDiag, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaComplexFiltration +Rcpp::List AlphaComplexFiltration(const Rcpp::NumericMatrix& X, const bool printProgress); +RcppExport SEXP _TDA_AlphaComplexFiltration(SEXP XSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaComplexFiltration(X, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaComplexDiag +Rcpp::List AlphaComplexDiag(const Rcpp::NumericMatrix& X, const int maxdimension, const std::string& libraryDiag, const bool location, const bool printProgress); +RcppExport SEXP _TDA_AlphaComplexDiag(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP libraryDiagSEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryDiag(libraryDiagSEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaComplexDiag(X, maxdimension, libraryDiag, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} + +static const R_CallMethodDef CallEntries[] = { + {"_TDA_GridFiltration", (DL_FUNC) &_TDA_GridFiltration, 5}, + {"_TDA_GridDiag", (DL_FUNC) &_TDA_GridDiag, 7}, + {"_TDA_Bottleneck", (DL_FUNC) &_TDA_Bottleneck, 2}, + {"_TDA_Wasserstein", (DL_FUNC) &_TDA_Wasserstein, 3}, + {"_TDA_Kde", (DL_FUNC) &_TDA_Kde, 6}, + {"_TDA_KdeDist", (DL_FUNC) &_TDA_KdeDist, 5}, + {"_TDA_Dtm", (DL_FUNC) &_TDA_Dtm, 3}, + {"_TDA_DtmWeight", (DL_FUNC) &_TDA_DtmWeight, 5}, + {"_TDA_FiltrationDiag", (DL_FUNC) &_TDA_FiltrationDiag, 5}, + {"_TDA_FunFiltration", (DL_FUNC) &_TDA_FunFiltration, 2}, + {"_TDA_RipsFiltration", (DL_FUNC) &_TDA_RipsFiltration, 6}, + {"_TDA_RipsDiag", (DL_FUNC) &_TDA_RipsDiag, 8}, + {"_TDA_AlphaShapeFiltration", (DL_FUNC) &_TDA_AlphaShapeFiltration, 2}, + {"_TDA_AlphaShapeDiag", (DL_FUNC) &_TDA_AlphaShapeDiag, 5}, + {"_TDA_AlphaComplexFiltration", (DL_FUNC) &_TDA_AlphaComplexFiltration, 2}, + {"_TDA_AlphaComplexDiag", (DL_FUNC) &_TDA_AlphaComplexDiag, 5}, + {NULL, NULL, 0} +}; + +RcppExport void R_init_TDA(DllInfo *dll) { + R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); + R_useDynamicSymbols(dll, FALSE); +} From 71f94f8010c5b3d0e1a270dda495bda2e13a5e2b Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:49:22 -0700 Subject: [PATCH 12/29] prep for rebasing master --- R/filtrationDiag.R | 9 ++++++--- src/diag.cpp | 4 ++-- src/rips.h | 2 +- src/tdautils/filtrationDiag.h | 8 +++++++- src/tdautils/typecastUtils.h | 25 +++++++++++++++++++++++-- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/R/filtrationDiag.R b/R/filtrationDiag.R index 5990fab..0945278 100644 --- a/R/filtrationDiag.R +++ b/R/filtrationDiag.R @@ -15,8 +15,11 @@ filtrationDiag <- function( if (library == "dionysus" || library == "DIONYSUS") { library <- "Dionysus" } - if (library != "GUDHI" && library != "Dionysus") { - stop("library for computing persistence diagram should be a string: either 'GUDHI' or 'Dionysus'") + if (library == "D2") { + library <- "D2" + } + if (library != "GUDHI" && library != "Dionysus" && library != "D2") { + stop("library for computing persistence diagram should be a string: either 'GUDHI' or 'Dionysus' or 'Dionysus2'") } if (!is.logical(location)) { stop("location should be logical") @@ -71,4 +74,4 @@ filtrationDiag <- function( "deathLocation" = DeathLocation, "cycleLocation" = CycleLocation) } return (out) -} \ No newline at end of file +} diff --git a/src/diag.cpp b/src/diag.cpp index aa44399..8fb1d21 100644 --- a/src/diag.cpp +++ b/src/diag.cpp @@ -20,7 +20,7 @@ // for Dionysus #include - +#include // for phat #include @@ -509,4 +509,4 @@ Rcpp::List AlphaComplexDiag( concatStlToRcpp< Rcpp::NumericMatrix >(persDgm, true, 3), concatStlToRcpp< Rcpp::NumericMatrix >(persLoc, false, 2), StlToRcppMatrixList< Rcpp::List, Rcpp::NumericMatrix >(persCycle)); -} \ No newline at end of file +} diff --git a/src/rips.h b/src/rips.h index 9d5c69c..32e811e 100644 --- a/src/rips.h +++ b/src/rips.h @@ -10,7 +10,7 @@ // for Dionysus #include -// for phat +// for phat #include // for Rips diff --git a/src/tdautils/filtrationDiag.h b/src/tdautils/filtrationDiag.h index a458a2a..fb649de 100644 --- a/src/tdautils/filtrationDiag.h +++ b/src/tdautils/filtrationDiag.h @@ -64,6 +64,12 @@ inline void filtrationDiagSorted( smplxTree, coeff_field_characteristic, min_persistence, maxdimension, printProgress, persDgm); } + else if (library[0] == 'D' && library[1] == '2') { + FiltrationDiagDionysus2( + filtrationTdaToDionysus2< VertexVector, Fltr2>( + cmplx, values, idxShift), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } else if (library[0] == 'D') { FiltrationDiagDionysus< Persistence >( filtrationTdaToDionysus< VertexVector, Fltr >( @@ -138,4 +144,4 @@ inline void filtrationDiag( -# endif // __FILTRATIONDIAG_H__ \ No newline at end of file +# endif // __FILTRATIONDIAG_H__ diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index 51472e2..7f195e2 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -664,7 +664,28 @@ inline Filtration filtrationTdaToDionysus( return filtration; } - +//Marker D2 +template< typename IntegerVector, typename Filtration, typename VectorList, + typename RealVector > +inline Filtration filtrationTdaToDionysus2( + const VectorList & cmplx, const RealVector & values, + const unsigned idxShift) { + Filtration filtration; + typename VectorList::const_iterator iCmplx = cmplx.begin(); + typename RealVector::const_iterator iValue = values.begin(); + for (; iCmplx != cmplx.end(); ++iCmplx, ++iValue) { + const IntegerVector tdaVec(*iCmplx); + IntegerVector dionysusVec(tdaVec.size()); + typename IntegerVector::const_iterator iTda = tdaVec.begin(); + typename IntegerVector::iterator iDionysus = dionysusVec.begin(); + for (; iTda != tdaVec.end(); ++iTda, ++iDionysus) { + // R is 1-base, while C++ is 0-base + *iDionysus = *iTda - idxShift; + } + filtration.push_back(typename Filtration::Cell(dionysusVec, *iValue)); + } + return filtration; +} template< typename Filtration, typename RcppVector, typename RcppList > inline Filtration filtrationRcppToDionysus(const RcppList & rcppList) { @@ -760,4 +781,4 @@ inline void filtrationDionysusToPhat( -# endif // __TYPECASTUTILS_H__ \ No newline at end of file +# endif // __TYPECASTUTILS_H__ From 5198f18e2ac7622c1333e7560cf5c171a2ab1007 Mon Sep 17 00:00:00 2001 From: thomashli Date: Thu, 13 Dec 2018 21:20:32 -0800 Subject: [PATCH 13/29] updating --- src/rips.h | 12 +++++++++- src/tdautils/diagramDS.h | 1 + src/tdautils/dionysus2Utils.h | 43 +++++++++++++++++++++++++++++++++++ src/tdautils/typecastUtils.h | 33 +++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/rips.h b/src/rips.h index 32e811e..f62d60f 100644 --- a/src/rips.h +++ b/src/rips.h @@ -14,6 +14,7 @@ #include // for Rips +#include #include #include @@ -59,10 +60,19 @@ inline void ripsFiltration( maxdimension, maxscale, printProgress, print); filtrationGudhiToTda< IntVector >(smplxTree, cmplx, values, boundary); } + else { if (dist[0] == 'e') { - // RipsDiag for L2 distance + // RipsDiag for L2 distance + /* + if (library[0] == 'D' && library([1] == '2') { + filtrationDionysus2ToTda< IntVector >( + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } + */ filtrationDionysusToTda< IntVector >( RipsFiltrationDionysus< PairDistances, Generator, FltrR >(X, nSample, nDim, false, maxdimension, maxscale, printProgress, print), diff --git a/src/tdautils/diagramDS.h b/src/tdautils/diagramDS.h index 84333b1..8cbca88 100644 --- a/src/tdautils/diagramDS.h +++ b/src/tdautils/diagramDS.h @@ -4,6 +4,7 @@ #include #include +#include namespace d = dionysus; diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h index 674c4ad..f379f45 100644 --- a/src/tdautils/dionysus2Utils.h +++ b/src/tdautils/dionysus2Utils.h @@ -85,4 +85,47 @@ void FiltrationDiagDionysus2( } } +template< typename Distances, typename Generator, typename Filtration, + typename RealMatrix, typename Print > +inline Filtration RipsFiltrationDionysus2( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const bool is_row_names, + const int maxdimension, + const double maxscale, + const bool printProgress, + const Print & print +) { + + PointContainer points = TdaToStl< PointContainer >(X, nSample, nDim, + is_row_names); + //lol copy paste + //read_points(infilename, points); + //read_points2(infilename, points); + + Distances distances(points); //PairDistances2 + Generator rips(distances); //Generator2 + typename Generator::Evaluator size(distances); + Filtration filtration; + //EvaluatePushBack< Filtration, typename Generator::Evaluator > functor(filtration, size); + auto functor = [&filtration](Simplex2&& s) { filtration.push_back(s); }; + // Generate maxdimension skeleton of the Rips complex + // rips.generate(skeleton, max_distance, [&filtration](Simplex&& s) { filtration.push_back(s); }); + + rips.generate(maxdimension + 1, maxscale, functor); + + if (printProgress) { + print("# Generated complex of size: %d \n", filtration.size()); + } + + // Sort the simplices with respect to comparison criteria + // e.g. distance or function values + // filtration.sort(ComparisonDataDimension< typename Filtration::Simplex >()); + filtration.sort(Generator::Comparison(distances)); + + return filtration; +} + + #endif __DIONYSUS2UTILS_H__ diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index 7f195e2..9a1c6cf 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -602,6 +602,39 @@ inline void filtrationDionysusToTda( } +template< typename IntegerVector, typename Filtration, typename VectorList, + typename RealVector > +inline void filtrationDionysus2ToTda( + const Filtration & filtration, VectorList & cmplx, RealVector & values, + VectorList & boundary) { + + const unsigned nFltr = filtration.size(); + std::map< typename Filtration::Cell, unsigned, + typename Filtration::Cell::VertexComparison > simplex_map; + unsigned size_of_simplex_map = 0; + + cmplx = VectorList(nFltr); + values = RealVector(nFltr); + boundary = VectorList(nFltr); + typename VectorList::iterator iCmplx = cmplx.begin(); + typename RealVector::iterator iValue = values.begin(); + typename VectorList::iterator iBdy = boundary.begin(); + + for (typename Filtration::Index it = filtration.begin(); + it != filtration.end(); ++it, ++iCmplx, ++iValue, ++iBdy) { + const typename Filtration::Simplex & c = filtration.simplex(it); + + IntegerVector cmplxVec; + IntegerVector boundaryVec; + filtrationDionysusOne(c, simplex_map, 1, cmplxVec, *iValue, boundaryVec); + *iCmplx = cmplxVec; + *iBdy = boundaryVec; + + simplex_map.insert(typename + std::map< typename Filtration::Simplex, unsigned >::value_type( + c, size_of_simplex_map++)); + } +} template< typename RcppList, typename RcppVector, typename Filtration > inline RcppList filtrationDionysusToRcpp(const Filtration & filtration) { From aa3a09d6ec4c2c3a4c51392625c30f648e24435c Mon Sep 17 00:00:00 2001 From: thomashli Date: Thu, 19 Jul 2018 11:29:03 -0600 Subject: [PATCH 14/29] moved dionysus up 1 --- src/dionysus/backward.hpp | 2212 ----------------- src/dionysus/{dionysus => }/chain.h | 0 src/dionysus/{dionysus => }/chain.hpp | 0 .../{dionysus => }/clearing-reduction.h | 0 .../{dionysus => }/clearing-reduction.hpp | 0 .../{dionysus => }/cohomology-persistence.h | 0 .../{dionysus => }/cohomology-persistence.hpp | 0 src/dionysus/{dionysus => }/diagram.h | 0 src/dionysus/{dionysus => }/distances.h | 0 src/dionysus/{dionysus => }/distances.hpp | 0 src/dionysus/{dionysus => }/dlog/progress.h | 0 src/dionysus/{dionysus => }/fields/q.h | 0 src/dionysus/{dionysus => }/fields/z2.h | 0 src/dionysus/{dionysus => }/fields/zp.h | 0 src/dionysus/{dionysus => }/filtration.h | 0 .../{dionysus => }/omni-field-persistence.h | 0 .../{dionysus => }/omni-field-persistence.hpp | 0 .../{dionysus => }/ordinary-persistence.h | 0 src/dionysus/{dionysus => }/pair-recorder.h | 0 src/dionysus/{dionysus => }/reduced-matrix.h | 0 .../{dionysus => }/reduced-matrix.hpp | 0 src/dionysus/{dionysus => }/reduction.h | 0 .../{dionysus => }/relative-homology-zigzag.h | 0 .../relative-homology-zigzag.hpp | 0 src/dionysus/{dionysus => }/rips.h | 0 src/dionysus/{dionysus => }/rips.hpp | 0 src/dionysus/{dionysus => }/row-reduction.h | 0 src/dionysus/{dionysus => }/row-reduction.hpp | 0 src/dionysus/{dionysus => }/simplex.h | 0 .../{dionysus => }/sparse-row-matrix.h | 0 .../{dionysus => }/sparse-row-matrix.hpp | 0 .../{dionysus => }/standard-reduction.h | 0 .../{dionysus => }/standard-reduction.hpp | 0 src/dionysus/{dionysus => }/trails-chains.h | 0 .../{dionysus => }/zigzag-persistence.h | 0 .../{dionysus => }/zigzag-persistence.hpp | 0 src/tdautils/dionysusUtils.h | 7 +- 37 files changed, 3 insertions(+), 2216 deletions(-) delete mode 100755 src/dionysus/backward.hpp rename src/dionysus/{dionysus => }/chain.h (100%) rename src/dionysus/{dionysus => }/chain.hpp (100%) rename src/dionysus/{dionysus => }/clearing-reduction.h (100%) rename src/dionysus/{dionysus => }/clearing-reduction.hpp (100%) rename src/dionysus/{dionysus => }/cohomology-persistence.h (100%) rename src/dionysus/{dionysus => }/cohomology-persistence.hpp (100%) rename src/dionysus/{dionysus => }/diagram.h (100%) rename src/dionysus/{dionysus => }/distances.h (100%) rename src/dionysus/{dionysus => }/distances.hpp (100%) rename src/dionysus/{dionysus => }/dlog/progress.h (100%) rename src/dionysus/{dionysus => }/fields/q.h (100%) rename src/dionysus/{dionysus => }/fields/z2.h (100%) rename src/dionysus/{dionysus => }/fields/zp.h (100%) rename src/dionysus/{dionysus => }/filtration.h (100%) rename src/dionysus/{dionysus => }/omni-field-persistence.h (100%) rename src/dionysus/{dionysus => }/omni-field-persistence.hpp (100%) rename src/dionysus/{dionysus => }/ordinary-persistence.h (100%) rename src/dionysus/{dionysus => }/pair-recorder.h (100%) rename src/dionysus/{dionysus => }/reduced-matrix.h (100%) rename src/dionysus/{dionysus => }/reduced-matrix.hpp (100%) rename src/dionysus/{dionysus => }/reduction.h (100%) rename src/dionysus/{dionysus => }/relative-homology-zigzag.h (100%) rename src/dionysus/{dionysus => }/relative-homology-zigzag.hpp (100%) rename src/dionysus/{dionysus => }/rips.h (100%) rename src/dionysus/{dionysus => }/rips.hpp (100%) rename src/dionysus/{dionysus => }/row-reduction.h (100%) rename src/dionysus/{dionysus => }/row-reduction.hpp (100%) rename src/dionysus/{dionysus => }/simplex.h (100%) rename src/dionysus/{dionysus => }/sparse-row-matrix.h (100%) rename src/dionysus/{dionysus => }/sparse-row-matrix.hpp (100%) rename src/dionysus/{dionysus => }/standard-reduction.h (100%) rename src/dionysus/{dionysus => }/standard-reduction.hpp (100%) rename src/dionysus/{dionysus => }/trails-chains.h (100%) rename src/dionysus/{dionysus => }/zigzag-persistence.h (100%) rename src/dionysus/{dionysus => }/zigzag-persistence.hpp (100%) diff --git a/src/dionysus/backward.hpp b/src/dionysus/backward.hpp deleted file mode 100755 index 6b331ba..0000000 --- a/src/dionysus/backward.hpp +++ /dev/null @@ -1,2212 +0,0 @@ -/* - * backward.hpp - * Copyright 2013 Google Inc. All Rights Reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#ifndef H_6B9572DA_A64B_49E6_B234_051480991C89 -#define H_6B9572DA_A64B_49E6_B234_051480991C89 - -#ifndef __cplusplus -# error "It's not going to compile without a C++ compiler..." -#endif - -#if defined(BACKWARD_CXX11) -#elif defined(BACKWARD_CXX98) -#else -# if __cplusplus >= 201103L -# define BACKWARD_CXX11 -# else -# define BACKWARD_CXX98 -# endif -#endif - -// You can define one of the following (or leave it to the auto-detection): -// -// #define BACKWARD_SYSTEM_LINUX -// - specialization for linux -// -// #define BACKWARD_SYSTEM_UNKNOWN -// - placebo implementation, does nothing. -// -#if defined(BACKWARD_SYSTEM_LINUX) -#elif defined(BACKWARD_SYSTEM_UNKNOWN) -#else -# if defined(__linux) -# define BACKWARD_SYSTEM_LINUX -# else -# define BACKWARD_SYSTEM_UNKNOWN -# endif -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(BACKWARD_SYSTEM_LINUX) - -// On linux, backtrace can back-trace or "walk" the stack using the following -// library: -// -// #define BACKWARD_HAS_UNWIND 1 -// - unwind comes from libgcc, but I saw an equivalent inside clang itself. -// - with unwind, the stacktrace is as accurate as it can possibly be, since -// this is used by the C++ runtine in gcc/clang for stack unwinding on -// exception. -// - normally libgcc is already linked to your program by default. -// -// #define BACKWARD_HAS_BACKTRACE == 1 -// - backtrace seems to be a little bit more portable than libunwind, but on -// linux, it uses unwind anyway, but abstract away a tiny information that is -// sadly really important in order to get perfectly accurate stack traces. -// - backtrace is part of the (e)glib library. -// -// The default is: -// #define BACKWARD_HAS_UNWIND == 1 -// -# if BACKWARD_HAS_UNWIND == 1 -# elif BACKWARD_HAS_BACKTRACE == 1 -# else -# undef BACKWARD_HAS_UNWIND -# define BACKWARD_HAS_UNWIND 1 -# undef BACKWARD_HAS_BACKTRACE -# define BACKWARD_HAS_BACKTRACE 0 -# endif - -// On linux, backward can extract detailed information about a stack trace -// using one of the following library: -// -// #define BACKWARD_HAS_DW 1 -// - libdw gives you the most juicy details out of your stack traces: -// - object filename -// - function name -// - source filename -// - line and column numbers -// - source code snippet (assuming the file is accessible) -// - variables name and values (if not optimized out) -// - You need to link with the lib "dw": -// - apt-get install libdw-dev -// - g++/clang++ -ldw ... -// -// #define BACKWARD_HAS_BFD 1 -// - With libbfd, you get a fair about of details: -// - object filename -// - function name -// - source filename -// - line numbers -// - source code snippet (assuming the file is accessible) -// - You need to link with the lib "bfd": -// - apt-get install binutils-dev -// - g++/clang++ -lbfd ... -// -// #define BACKWARD_HAS_BACKTRACE_SYMBOL 1 -// - backtrace provides minimal details for a stack trace: -// - object filename -// - function name -// - backtrace is part of the (e)glib library. -// -// The default is: -// #define BACKWARD_HAS_BACKTRACE_SYMBOL == 1 -// -# if BACKWARD_HAS_DW == 1 -# elif BACKWARD_HAS_BFD == 1 -# elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 -# else -# undef BACKWARD_HAS_DW -# define BACKWARD_HAS_DW 0 -# undef BACKWARD_HAS_BFD -# define BACKWARD_HAS_BFD 0 -# undef BACKWARD_HAS_BACKTRACE_SYMBOL -# define BACKWARD_HAS_BACKTRACE_SYMBOL 1 -# endif - - -# if BACKWARD_HAS_UNWIND == 1 - -# include -// while gcc's unwind.h defines something like that: -// extern _Unwind_Ptr _Unwind_GetIP (struct _Unwind_Context *); -// extern _Unwind_Ptr _Unwind_GetIPInfo (struct _Unwind_Context *, int *); -// -// clang's unwind.h defines something like this: -// uintptr_t _Unwind_GetIP(struct _Unwind_Context* __context); -// -// Even if the _Unwind_GetIPInfo can be linked to, it is not declared, worse we -// cannot just redeclare it because clang's unwind.h doesn't define _Unwind_Ptr -// anyway. -// -// Luckily we can play on the fact that the guard macros have a different name: -#ifdef __CLANG_UNWIND_H -// In fact, this function still comes from libgcc (on my different linux boxes, -// clang links against libgcc). -# include -extern "C" uintptr_t _Unwind_GetIPInfo(_Unwind_Context*, int*); -#endif - -# endif - -# include -# include -# include -# include -# include -# include -# include - -# if BACKWARD_HAS_BFD == 1 -# include -# ifndef _GNU_SOURCE -# define _GNU_SOURCE -# include -# undef _GNU_SOURCE -# else -# include -# endif -# endif - -# if BACKWARD_HAS_DW == 1 -# include -# include -# include -# endif - -# if (BACKWARD_HAS_BACKTRACE == 1) || (BACKWARD_HAS_BACKTRACE_SYMBOL == 1) - // then we shall rely on backtrace -# include -# endif - -#endif // defined(BACKWARD_SYSTEM_LINUX) - -#if defined(BACKWARD_CXX11) -# include -# include // for std::swap - namespace backward { - namespace details { - template - struct hashtable { - typedef std::unordered_map type; - }; - using std::move; - } // namespace details - } // namespace backward -#elif defined(BACKWARD_CXX98) -# include - namespace backward { - namespace details { - template - struct hashtable { - typedef std::map type; - }; - template - const T& move(const T& v) { return v; } - template - T& move(T& v) { return v; } - } // namespace details - } // namespace backward -#else -# error "Mmm if its not C++11 nor C++98... go play in the toaster." -#endif - -namespace backward { - -namespace system_tag { - struct linux_tag; // seems that I cannot call that "linux" because the name - // is already defined... so I am adding _tag everywhere. - struct unknown_tag; - -#if defined(BACKWARD_SYSTEM_LINUX) - typedef linux_tag current_tag; -#elif defined(BACKWARD_SYSTEM_UNKNOWN) - typedef unknown_tag current_tag; -#else -# error "May I please get my system defines?" -#endif -} // namespace system_tag - - -namespace stacktrace_tag { -#ifdef BACKWARD_SYSTEM_LINUX - struct unwind; - struct backtrace; - -# if BACKWARD_HAS_UNWIND == 1 - typedef unwind current; -# elif BACKWARD_HAS_BACKTRACE == 1 - typedef backtrace current; -# else -# error "I know it's difficult but you need to make a choice!" -# endif -#endif // BACKWARD_SYSTEM_LINUX -} // namespace stacktrace_tag - - -namespace trace_resolver_tag { -#ifdef BACKWARD_SYSTEM_LINUX - struct libdw; - struct libbfd; - struct backtrace_symbol; - -# if BACKWARD_HAS_DW == 1 - typedef libdw current; -# elif BACKWARD_HAS_BFD == 1 - typedef libbfd current; -# elif BACKWARD_HAS_BACKTRACE_SYMBOL == 1 - typedef backtrace_symbol current; -# else -# error "You shall not pass, until you know what you want." -# endif -#endif // BACKWARD_SYSTEM_LINUX -} // namespace trace_resolver_tag - -namespace details { - -template - struct rm_ptr { typedef T type; }; - -template - struct rm_ptr { typedef T type; }; - -template - struct rm_ptr { typedef const T type; }; - -template -struct deleter { - template - void operator()(U& ptr) const { - (*F)(ptr); - } -}; - -template -struct default_delete { - void operator()(T& ptr) const { - delete ptr; - } -}; - -template > -class handle { - struct dummy; - T _val; - bool _empty; - -#if defined(BACKWARD_CXX11) - handle(const handle&) = delete; - handle& operator=(const handle&) = delete; -#endif - -public: - ~handle() { - if (not _empty) { - Deleter()(_val); - } - } - - explicit handle(): _val(), _empty(true) {} - explicit handle(T val): _val(val), _empty(false) {} - -#if defined(BACKWARD_CXX11) - handle(handle&& from): _empty(true) { - swap(from); - } - handle& operator=(handle&& from) { - swap(from); return *this; - } -#else - explicit handle(const handle& from): _empty(true) { - // some sort of poor man's move semantic. - swap(const_cast(from)); - } - handle& operator=(const handle& from) { - // some sort of poor man's move semantic. - swap(const_cast(from)); return *this; - } -#endif - - void reset(T new_val) { - handle tmp(new_val); - swap(tmp); - } - operator const dummy*() const { - if (_empty) { - return 0; - } - return reinterpret_cast(_val); - } - T get() { - return _val; - } - T release() { - _empty = true; - return _val; - } - void swap(handle& b) { - using std::swap; - swap(b._val, _val); // can throw, we are safe here. - swap(b._empty, _empty); // should not throw: if you cannot swap two - // bools without throwing... It's a lost cause anyway! - } - - T operator->() { return _val; } - const T operator->() const { return _val; } - - typedef typename rm_ptr::type& ref_t; - ref_t operator*() { return *_val; } - const ref_t operator*() const { return *_val; } - ref_t operator[](size_t idx) { return _val[idx]; } - - // Watch out, we've got a badass over here - T* operator&() { - _empty = false; - return &_val; - } -}; - -} // namespace details - -/*************** A TRACE ***************/ - -struct Trace { - void* addr; - size_t idx; - - Trace(): - addr(0), idx(0) {} - - explicit Trace(void* addr, size_t idx): - addr(addr), idx(idx) {} -}; - -// Really simple, generic, and dumb representation of a variable. -// A variable has a name and can represent either: -// - a value (as a string) -// - a list of values (a list of strings) -// - a map of values (a list of variable) -class Variable { -public: - enum Kind { VALUE, LIST, MAP }; - - typedef std::vector list_t; - typedef std::vector map_t; - - std::string name; - Kind kind; - - Variable(Kind k): kind(k) { - switch (k) { - case VALUE: - new (&storage) std::string(); - break; - - case LIST: - new (&storage) list_t(); - break; - - case MAP: - new (&storage) map_t(); - break; - } - } - - std::string& value() { - return reinterpret_cast(storage); - } - list_t& list() { - return reinterpret_cast(storage); - } - map_t& map() { - return reinterpret_cast(storage); - } - - - const std::string& value() const { - return reinterpret_cast(storage); - } - const list_t& list() const { - return reinterpret_cast(storage); - } - const map_t& map() const { - return reinterpret_cast(storage); - } - -private: - // the C++98 style union for non-trivial objects, yes yes I know, its not - // aligned as good as it can be, blabla... Screw this. - union { - char s1[sizeof (std::string)]; - char s2[sizeof (list_t)]; - char s3[sizeof (map_t)]; - } storage; -}; - -struct TraceWithLocals: public Trace { - // Locals variable and values. - std::vector locals; - - TraceWithLocals(): Trace() {} - TraceWithLocals(const Trace& mini_trace): - Trace(mini_trace) {} -}; - -struct ResolvedTrace: public TraceWithLocals { - - struct SourceLoc { - std::string function; - std::string filename; - unsigned line; - unsigned col; - - SourceLoc(): line(0), col(0) {} - - bool operator==(const SourceLoc& b) const { - return function == b.function - and filename == b.filename - and line == b.line - and col == b.col; - } - - bool operator!=(const SourceLoc& b) const { - return not (*this == b); - } - }; - - // In which binary object this trace is located. - std::string object_filename; - - // The function in the object that contain the trace. This is not the same - // as source.function which can be an function inlined in object_function. - std::string object_function; - - // The source location of this trace. It is possible for filename to be - // empty and for line/col to be invalid (value 0) if this information - // couldn't be deduced, for example if there is no debug information in the - // binary object. - SourceLoc source; - - // An optionals list of "inliners". All the successive sources location - // from where the source location of the trace (the attribute right above) - // is inlined. It is especially useful when you compiled with optimization. - typedef std::vector source_locs_t; - source_locs_t inliners; - - ResolvedTrace(const Trace& mini_trace): - TraceWithLocals(mini_trace) {} - ResolvedTrace(const TraceWithLocals& mini_trace_with_locals): - TraceWithLocals(mini_trace_with_locals) {} -}; - -/*************** STACK TRACE ***************/ - -// default implemention. -template -class StackTraceImpl { -public: - size_t size() const { return 0; } - Trace operator[](size_t) { return Trace(); } - size_t load_here(size_t=0) { return 0; } - size_t load_from(void*, size_t=0) { return 0; } - unsigned thread_id() const { return 0; } -}; - -#ifdef BACKWARD_SYSTEM_LINUX - -class StackTraceLinuxImplBase { -public: - StackTraceLinuxImplBase(): _thread_id(0), _skip(0) {} - - unsigned thread_id() const { - return _thread_id; - } - -protected: - void load_thread_info() { - _thread_id = syscall(SYS_gettid); - if (_thread_id == (size_t) getpid()) { - // If the thread is the main one, let's hide that. - // I like to keep little secret sometimes. - _thread_id = 0; - } - } - - void skip_n_firsts(size_t n) { _skip = n; } - size_t skip_n_firsts() const { return _skip; } - -private: - size_t _thread_id; - size_t _skip; -}; - -class StackTraceLinuxImplHolder: public StackTraceLinuxImplBase { -public: - size_t size() const { - return _stacktrace.size() ? _stacktrace.size() - skip_n_firsts() : 0; - } - Trace operator[](size_t idx) { - if (idx >= size()) { - return Trace(); - } - return Trace(_stacktrace[idx + skip_n_firsts()], idx); - } - void** begin() { - if (size()) { - return &_stacktrace[skip_n_firsts()]; - } - return 0; - } - -protected: - std::vector _stacktrace; -}; - - -#if BACKWARD_HAS_UNWIND == 1 - -namespace details { - -template -class Unwinder { -public: - size_t operator()(F& f, size_t depth) { - _f = &f; - _index = -1; - _depth = depth; - _Unwind_Backtrace(&this->backtrace_trampoline, this); - return _index; - } - -private: - F* _f; - ssize_t _index; - size_t _depth; - - static _Unwind_Reason_Code backtrace_trampoline( - _Unwind_Context* ctx, void *self) { - return ((Unwinder*)self)->backtrace(ctx); - } - - _Unwind_Reason_Code backtrace(_Unwind_Context* ctx) { - if (_index >= 0 and static_cast(_index) >= _depth) - return _URC_END_OF_STACK; - - int ip_before_instruction = 0; - uintptr_t ip = _Unwind_GetIPInfo(ctx, &ip_before_instruction); - - if (not ip_before_instruction) { - ip -= 1; - } - - if (_index >= 0) { // ignore first frame. - (*_f)(_index, (void*)ip); - } - _index += 1; - return _URC_NO_REASON; - } -}; - -template -size_t unwind(F f, size_t depth) { - Unwinder unwinder; - return unwinder(f, depth); -} - -} // namespace details - - -template <> -class StackTraceImpl: public StackTraceLinuxImplHolder { -public: - __attribute__ ((noinline)) // TODO use some macro - size_t load_here(size_t depth=32) { - load_thread_info(); - if (depth == 0) { - return 0; - } - _stacktrace.resize(depth); - size_t trace_cnt = details::unwind(callback(*this), depth); - _stacktrace.resize(trace_cnt); - skip_n_firsts(0); - return size(); - } - size_t load_from(void* addr, size_t depth=32) { - load_here(depth + 8); - - for (size_t i = 0; i < _stacktrace.size(); ++i) { - if (_stacktrace[i] == addr) { - skip_n_firsts(i); - break; - } - } - - _stacktrace.resize(std::min(_stacktrace.size(), - skip_n_firsts() + depth)); - return size(); - } - -private: - struct callback { - StackTraceImpl& self; - callback(StackTraceImpl& self): self(self) {} - - void operator()(size_t idx, void* addr) { - self._stacktrace[idx] = addr; - } - }; -}; - - -#else // BACKWARD_HAS_UNWIND == 0 - -template <> -class StackTraceImpl: public StackTraceLinuxImplHolder { -public: - __attribute__ ((noinline)) // TODO use some macro - size_t load_here(size_t depth=32) { - load_thread_info(); - if (depth == 0) { - return 0; - } - _stacktrace.resize(depth + 1); - size_t trace_cnt = backtrace(&_stacktrace[0], _stacktrace.size()); - _stacktrace.resize(trace_cnt); - skip_n_firsts(1); - return size(); - } - - size_t load_from(void* addr, size_t depth=32) { - load_here(depth + 8); - - for (size_t i = 0; i < _stacktrace.size(); ++i) { - if (_stacktrace[i] == addr) { - skip_n_firsts(i); - _stacktrace[i] = (void*)( (uintptr_t)_stacktrace[i] + 1); - break; - } - } - - _stacktrace.resize(std::min(_stacktrace.size(), - skip_n_firsts() + depth)); - return size(); - } -}; - -#endif // BACKWARD_HAS_UNWIND -#endif // BACKWARD_SYSTEM_LINUX - -class StackTrace: - public StackTraceImpl {}; - -/*********** STACKTRACE WITH LOCALS ***********/ - -// default implemention. -template -class StackTraceWithLocalsImpl: - public StackTrace {}; - -#ifdef BACKWARD_SYSTEM_LINUX -#if BACKWARD_HAS_UNWIND -#if BACKWARD_HAS_DW - -template <> -class StackTraceWithLocalsImpl: - public StackTraceLinuxImplBase { -public: - __attribute__ ((noinline)) // TODO use some macro - size_t load_here(size_t depth=32) { - load_thread_info(); - if (depth == 0) { - return 0; - } - _stacktrace.resize(depth); - size_t trace_cnt = details::unwind(callback(*this), depth); - _stacktrace.resize(trace_cnt); - skip_n_firsts(0); - return size(); - } - size_t load_from(void* addr, size_t depth=32) { - load_here(depth + 8); - - for (size_t i = 0; i < _stacktrace.size(); ++i) { - if (_stacktrace[i].addr == addr) { - skip_n_firsts(i); - break; - } - } - _stacktrace.resize(std::min(_stacktrace.size(), - skip_n_firsts() + depth)); - return size(); - } - size_t size() const { - return _stacktrace.size() ? _stacktrace.size() - skip_n_firsts() : 0; - } - const TraceWithLocals& operator[](size_t idx) { - if (idx >= size()) { - return _nil_trace; - } - return _stacktrace[idx + skip_n_firsts()]; - } - -private: - std::vector _stacktrace; - TraceWithLocals _nil_trace; - - void resolve_trace(TraceWithLocals& trace) { - Variable v(Variable::VALUE); - v.name = "var"; - v.value() = "42"; - trace.locals.push_back(v); - } - - struct callback { - StackTraceWithLocalsImpl& self; - callback(StackTraceWithLocalsImpl& self): self(self) {} - - void operator()(size_t idx, void* addr) { - self._stacktrace[idx].addr = addr; - self.resolve_trace(self._stacktrace[idx]); - } - }; -}; - -#endif // BACKWARD_HAS_DW -#endif // BACKWARD_HAS_UNWIND -#endif // BACKWARD_SYSTEM_LINUX - -class StackTraceWithLocals: - public StackTraceWithLocalsImpl {}; - -/*************** TRACE RESOLVER ***************/ - -template -class TraceResolverImpl; - -#ifdef BACKWARD_SYSTEM_UNKNOWN - -template <> -class TraceResolverImpl { -public: - template - void load_stacktrace(ST&) {} - ResolvedTrace resolve(ResolvedTrace t) { - return t; - } -}; - -#endif - -#ifdef BACKWARD_SYSTEM_LINUX - -class TraceResolverLinuxImplBase { -protected: - std::string demangle(const char* funcname) { - using namespace details; - _demangle_buffer.reset( - abi::__cxa_demangle(funcname, _demangle_buffer.release(), - &_demangle_buffer_length, 0) - ); - if (_demangle_buffer) { - return _demangle_buffer.get(); - } - return funcname; - } - -private: - details::handle _demangle_buffer; - size_t _demangle_buffer_length; -}; - -template -class TraceResolverLinuxImpl; - -#if BACKWARD_HAS_BACKTRACE_SYMBOL == 1 - -template <> -class TraceResolverLinuxImpl: - public TraceResolverLinuxImplBase { -public: - template - void load_stacktrace(ST& st) { - using namespace details; - if (st.size() == 0) { - return; - } - _symbols.reset( - backtrace_symbols(st.begin(), st.size()) - ); - } - - ResolvedTrace resolve(ResolvedTrace trace) { - char* filename = _symbols[trace.idx]; - char* funcname = filename; - while (*funcname && *funcname != '(') { - funcname += 1; - } - trace.object_filename.assign(filename, funcname++); - char* funcname_end = funcname; - while (*funcname_end && *funcname_end != ')' && *funcname_end != '+') { - funcname_end += 1; - } - *funcname_end = '\0'; - trace.object_function = this->demangle(funcname); - trace.source.function = trace.object_function; // we cannot do better. - return trace; - } - -private: - details::handle _symbols; -}; - -#endif // BACKWARD_HAS_BACKTRACE_SYMBOL == 1 - -#if BACKWARD_HAS_BFD == 1 - -template <> -class TraceResolverLinuxImpl: - public TraceResolverLinuxImplBase { -public: - TraceResolverLinuxImpl(): _bfd_loaded(false) {} - - template - void load_stacktrace(ST&) {} - - ResolvedTrace resolve(ResolvedTrace trace) { - Dl_info symbol_info; - - // trace.addr is a virtual address in memory pointing to some code. - // Let's try to find from which loaded object it comes from. - // The loaded object can be yourself btw. - if (not dladdr(trace.addr, &symbol_info)) { - return trace; // dat broken trace... - } - - // Now we get in symbol_info: - // .dli_fname: - // pathname of the shared object that contains the address. - // .dli_fbase: - // where the object is loaded in memory. - // .dli_sname: - // the name of the nearest symbol to trace.addr, we expect a - // function name. - // .dli_saddr: - // the exact address corresponding to .dli_sname. - - if (symbol_info.dli_sname) { - trace.object_function = demangle(symbol_info.dli_sname); - } - - if (not symbol_info.dli_fname) { - return trace; - } - - trace.object_filename = symbol_info.dli_fname; - bfd_fileobject& fobj = load_object_with_bfd(symbol_info.dli_fname); - if (not fobj.handle) { - return trace; // sad, we couldn't load the object :( - } - - - find_sym_result* details_selected; // to be filled. - - // trace.addr is the next instruction to be executed after returning - // from the nested stack frame. In C++ this usually relate to the next - // statement right after the function call that leaded to a new stack - // frame. This is not usually what you want to see when printing out a - // stacktrace... - find_sym_result details_call_site = find_symbol_details(fobj, - trace.addr, symbol_info.dli_fbase); - details_selected = &details_call_site; - -#if BACKWARD_HAS_UNWIND == 0 - // ...this is why we also try to resolve the symbol that is right - // before the return address. If we are lucky enough, we will get the - // line of the function that was called. But if the code is optimized, - // we might get something absolutely not related since the compiler - // can reschedule the return address with inline functions and - // tail-call optimisation (among other things that I don't even know - // or cannot even dream about with my tiny limited brain). - find_sym_result details_adjusted_call_site = find_symbol_details(fobj, - (void*) (uintptr_t(trace.addr) - 1), - symbol_info.dli_fbase); - - // In debug mode, we should always get the right thing(TM). - if (details_call_site.found and details_adjusted_call_site.found) { - // Ok, we assume that details_adjusted_call_site is a better estimation. - details_selected = &details_adjusted_call_site; - trace.addr = (void*) (uintptr_t(trace.addr) - 1); - } - - if (details_selected == &details_call_site and details_call_site.found) { - // we have to re-resolve the symbol in order to reset some - // internal state in BFD... so we can call backtrace_inliners - // thereafter... - details_call_site = find_symbol_details(fobj, trace.addr, - symbol_info.dli_fbase); - } -#endif // BACKWARD_HAS_UNWIND - - if (details_selected->found) { - if (details_selected->filename) { - trace.source.filename = details_selected->filename; - } - trace.source.line = details_selected->line; - - if (details_selected->funcname) { - // this time we get the name of the function where the code is - // located, instead of the function were the address is - // located. In short, if the code was inlined, we get the - // function correspoding to the code. Else we already got in - // trace.function. - trace.source.function = demangle(details_selected->funcname); - - if (not symbol_info.dli_sname) { - // for the case dladdr failed to find the symbol name of - // the function, we might as well try to put something - // here. - trace.object_function = trace.source.function; - } - } - - // Maybe the source of the trace got inlined inside the function - // (trace.source.function). Let's see if we can get all the inlined - // calls along the way up to the initial call site. - trace.inliners = backtrace_inliners(fobj, *details_selected); - -#if 0 - if (trace.inliners.size() == 0) { - // Maybe the trace was not inlined... or maybe it was and we - // are lacking the debug information. Let's try to make the - // world better and see if we can get the line number of the - // function (trace.source.function) now. - // - // We will get the location of where the function start (to be - // exact: the first instruction that really start the - // function), not where the name of the function is defined. - // This can be quite far away from the name of the function - // btw. - // - // If the source of the function is the same as the source of - // the trace, we cannot say if the trace was really inlined or - // not. However, if the filename of the source is different - // between the function and the trace... we can declare it as - // an inliner. This is not 100% accurate, but better than - // nothing. - - if (symbol_info.dli_saddr) { - find_sym_result details = find_symbol_details(fobj, - symbol_info.dli_saddr, - symbol_info.dli_fbase); - - if (details.found) { - ResolvedTrace::SourceLoc diy_inliner; - diy_inliner.line = details.line; - if (details.filename) { - diy_inliner.filename = details.filename; - } - if (details.funcname) { - diy_inliner.function = demangle(details.funcname); - } else { - diy_inliner.function = trace.source.function; - } - if (diy_inliner != trace.source) { - trace.inliners.push_back(diy_inliner); - } - } - } - } -#endif - } - - return trace; - } - -private: - bool _bfd_loaded; - - typedef details::handle - > bfd_handle_t; - - typedef details::handle bfd_symtab_t; - - - struct bfd_fileobject { - bfd_handle_t handle; - bfd_vma base_addr; - bfd_symtab_t symtab; - bfd_symtab_t dynamic_symtab; - }; - - typedef details::hashtable::type - fobj_bfd_map_t; - fobj_bfd_map_t _fobj_bfd_map; - - bfd_fileobject& load_object_with_bfd(const std::string& filename_object) { - using namespace details; - - if (not _bfd_loaded) { - using namespace details; - bfd_init(); - _bfd_loaded = true; - } - - fobj_bfd_map_t::iterator it = - _fobj_bfd_map.find(filename_object); - if (it != _fobj_bfd_map.end()) { - return it->second; - } - - // this new object is empty for now. - bfd_fileobject& r = _fobj_bfd_map[filename_object]; - - // we do the work temporary in this one; - bfd_handle_t bfd_handle; - - int fd = open(filename_object.c_str(), O_RDONLY); - bfd_handle.reset( - bfd_fdopenr(filename_object.c_str(), "default", fd) - ); - if (not bfd_handle) { - close(fd); - return r; - } - - if (not bfd_check_format(bfd_handle.get(), bfd_object)) { - return r; // not an object? You lose. - } - - if ((bfd_get_file_flags(bfd_handle.get()) & HAS_SYMS) == 0) { - return r; // that's what happen when you forget to compile in debug. - } - - ssize_t symtab_storage_size = - bfd_get_symtab_upper_bound(bfd_handle.get()); - - ssize_t dyn_symtab_storage_size = - bfd_get_dynamic_symtab_upper_bound(bfd_handle.get()); - - if (symtab_storage_size <= 0 and dyn_symtab_storage_size <= 0) { - return r; // weird, is the file is corrupted? - } - - bfd_symtab_t symtab, dynamic_symtab; - ssize_t symcount = 0, dyn_symcount = 0; - - if (symtab_storage_size > 0) { - symtab.reset( - (bfd_symbol**) malloc(symtab_storage_size) - ); - symcount = bfd_canonicalize_symtab( - bfd_handle.get(), symtab.get() - ); - } - - if (dyn_symtab_storage_size > 0) { - dynamic_symtab.reset( - (bfd_symbol**) malloc(dyn_symtab_storage_size) - ); - dyn_symcount = bfd_canonicalize_dynamic_symtab( - bfd_handle.get(), dynamic_symtab.get() - ); - } - - - if (symcount <= 0 and dyn_symcount <= 0) { - return r; // damned, that's a stripped file that you got there! - } - - r.handle = move(bfd_handle); - r.symtab = move(symtab); - r.dynamic_symtab = move(dynamic_symtab); - return r; - } - - struct find_sym_result { - bool found; - const char* filename; - const char* funcname; - unsigned int line; - }; - - struct find_sym_context { - TraceResolverLinuxImpl* self; - bfd_fileobject* fobj; - void* addr; - void* base_addr; - find_sym_result result; - }; - - find_sym_result find_symbol_details(bfd_fileobject& fobj, void* addr, - void* base_addr) { - find_sym_context context; - context.self = this; - context.fobj = &fobj; - context.addr = addr; - context.base_addr = base_addr; - context.result.found = false; - bfd_map_over_sections(fobj.handle.get(), &find_in_section_trampoline, - (void*)&context); - return context.result; - } - - static void find_in_section_trampoline(bfd*, asection* section, - void* data) { - find_sym_context* context = static_cast(data); - context->self->find_in_section( - reinterpret_cast(context->addr), - reinterpret_cast(context->base_addr), - *context->fobj, - section, context->result - ); - } - - void find_in_section(bfd_vma addr, bfd_vma base_addr, - bfd_fileobject& fobj, asection* section, find_sym_result& result) - { - if (result.found) return; - - if ((bfd_get_section_flags(fobj.handle.get(), section) - & SEC_ALLOC) == 0) - return; // a debug section is never loaded automatically. - - bfd_vma sec_addr = bfd_get_section_vma(fobj.handle.get(), section); - bfd_size_type size = bfd_get_section_size(section); - - // are we in the boundaries of the section? - if (addr < sec_addr or addr >= sec_addr + size) { - addr -= base_addr; // oups, a relocated object, lets try again... - if (addr < sec_addr or addr >= sec_addr + size) { - return; - } - } - - if (not result.found and fobj.symtab) { - result.found = bfd_find_nearest_line(fobj.handle.get(), section, - fobj.symtab.get(), addr - sec_addr, &result.filename, - &result.funcname, &result.line); - } - - if (not result.found and fobj.dynamic_symtab) { - result.found = bfd_find_nearest_line(fobj.handle.get(), section, - fobj.dynamic_symtab.get(), addr - sec_addr, - &result.filename, &result.funcname, &result.line); - } - - } - - ResolvedTrace::source_locs_t backtrace_inliners(bfd_fileobject& fobj, - find_sym_result previous_result) { - // This function can be called ONLY after a SUCCESSFUL call to - // find_symbol_details. The state is global to the bfd_handle. - ResolvedTrace::source_locs_t results; - while (previous_result.found) { - find_sym_result result; - result.found = bfd_find_inliner_info(fobj.handle.get(), - &result.filename, &result.funcname, &result.line); - - if (result.found) /* and not ( - cstrings_eq(previous_result.filename, result.filename) - and cstrings_eq(previous_result.funcname, result.funcname) - and result.line == previous_result.line - )) */ { - ResolvedTrace::SourceLoc src_loc; - src_loc.line = result.line; - if (result.filename) { - src_loc.filename = result.filename; - } - if (result.funcname) { - src_loc.function = demangle(result.funcname); - } - results.push_back(src_loc); - } - previous_result = result; - } - return results; - } - - bool cstrings_eq(const char* a, const char* b) { - if (not a or not b) { - return false; - } - return strcmp(a, b) == 0; - } - -}; -#endif // BACKWARD_HAS_BFD == 1 - -#if BACKWARD_HAS_DW == 1 - -template <> -class TraceResolverLinuxImpl: - public TraceResolverLinuxImplBase { -public: - TraceResolverLinuxImpl(): _dwfl_handle_initialized(false) {} - - template - void load_stacktrace(ST&) {} - - ResolvedTrace resolve(ResolvedTrace trace) { - using namespace details; - - Dwarf_Addr trace_addr = (Dwarf_Addr) trace.addr; - - if (not _dwfl_handle_initialized) { - // initialize dwfl... - _dwfl_cb.reset(new Dwfl_Callbacks); - _dwfl_cb->find_elf = &dwfl_linux_proc_find_elf; - _dwfl_cb->find_debuginfo = &dwfl_standard_find_debuginfo; - _dwfl_cb->debuginfo_path = 0; - - _dwfl_handle.reset(dwfl_begin(_dwfl_cb.get())); - _dwfl_handle_initialized = true; - - if (not _dwfl_handle) { - return trace; - } - - // ...from the current process. - dwfl_report_begin(_dwfl_handle.get()); - int r = dwfl_linux_proc_report (_dwfl_handle.get(), getpid()); - dwfl_report_end(_dwfl_handle.get(), NULL, NULL); - if (r < 0) { - return trace; - } - } - - if (not _dwfl_handle) { - return trace; - } - - // find the module (binary object) that contains the trace's address. - // This is not using any debug information, but the addresses ranges of - // all the currently loaded binary object. - Dwfl_Module* mod = dwfl_addrmodule(_dwfl_handle.get(), trace_addr); - if (mod) { - // now that we found it, lets get the name of it, this will be the - // full path to the running binary or one of the loaded library. - const char* module_name = dwfl_module_info (mod, - 0, 0, 0, 0, 0, 0, 0); - if (module_name) { - trace.object_filename = module_name; - } - // We also look after the name of the symbol, equal or before this - // address. This is found by walking the symtab. We should get the - // symbol corresponding to the function (mangled) containing the - // address. If the code corresponding to the address was inlined, - // this is the name of the out-most inliner function. - const char* sym_name = dwfl_module_addrname(mod, trace_addr); - if (sym_name) { - trace.object_function = demangle(sym_name); - } - } - - // now let's get serious, and find out the source location (file and - // line number) of the address. - - // This function will look in .debug_aranges for the address and map it - // to the location of the compilation unit DIE in .debug_info and - // return it. - Dwarf_Addr mod_bias = 0; - Dwarf_Die* cudie = dwfl_module_addrdie(mod, trace_addr, &mod_bias); - -#if 1 - if (not cudie) { - // Sadly clang does not generate the section .debug_aranges, thus - // dwfl_module_addrdie will fail early. Clang doesn't either set - // the lowpc/highpc/range info for every compilation unit. - // - // So in order to save the world: - // for every compilation unit, we will iterate over every single - // DIEs. Normally functions should have a lowpc/highpc/range, which - // we will use to infer the compilation unit. - - // note that this is probably badly inefficient. - while ((cudie = dwfl_module_nextcu(mod, cudie, &mod_bias))) { - Dwarf_Die die_mem; - Dwarf_Die* fundie = find_fundie_by_pc(cudie, - trace_addr - mod_bias, &die_mem); - if (fundie) { - break; - } - } - } -#endif - -//#define BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE -#ifdef BACKWARD_I_DO_NOT_RECOMMEND_TO_ENABLE_THIS_HORRIBLE_PIECE_OF_CODE - if (not cudie) { - // If it's still not enough, lets dive deeper in the shit, and try - // to save the world again: for every compilation unit, we will - // load the corresponding .debug_line section, and see if we can - // find our address in it. - - Dwarf_Addr cfi_bias; - Dwarf_CFI* cfi_cache = dwfl_module_eh_cfi(mod, &cfi_bias); - - Dwarf_Addr bias; - while ((cudie = dwfl_module_nextcu(mod, cudie, &bias))) { - if (dwarf_getsrc_die(cudie, trace_addr - bias)) { - - // ...but if we get a match, it might be a false positive - // because our (address - bias) might as well be valid in a - // different compilation unit. So we throw our last card on - // the table and lookup for the address into the .eh_frame - // section. - - handle frame; - dwarf_cfi_addrframe(cfi_cache, trace_addr - cfi_bias, &frame); - if (frame) { - break; - } - } - } - } -#endif - - if (not cudie) { - return trace; // this time we lost the game :/ - } - - // Now that we have a compilation unit DIE, this function will be able - // to load the corresponding section in .debug_line (if not already - // loaded) and hopefully find the source location mapped to our - // address. - Dwarf_Line* srcloc = dwarf_getsrc_die(cudie, trace_addr - mod_bias); - - if (srcloc) { - const char* srcfile = dwarf_linesrc(srcloc, 0, 0); - if (srcfile) { - trace.source.filename = srcfile; - } - int line = 0, col = 0; - dwarf_lineno(srcloc, &line); - dwarf_linecol(srcloc, &col); - trace.source.line = line; - trace.source.col = col; - } - - deep_first_search_by_pc(cudie, trace_addr - mod_bias, - inliners_search_cb(trace)); - if (trace.source.function.size() == 0) { - // fallback. - trace.source.function = trace.object_function; - } - - return trace; - } - -private: - typedef details::handle > - dwfl_handle_t; - details::handle > - _dwfl_cb; - dwfl_handle_t _dwfl_handle; - bool _dwfl_handle_initialized; - - // defined here because in C++98, template function cannot take locally - // defined types... grrr. - struct inliners_search_cb { - void operator()(Dwarf_Die* die) { - switch (dwarf_tag(die)) { - const char* name; - case DW_TAG_subprogram: - if ((name = dwarf_diename(die))) { - trace.source.function = name; - } - break; - - case DW_TAG_inlined_subroutine: - ResolvedTrace::SourceLoc sloc; - Dwarf_Attribute attr_mem; - - if ((name = dwarf_diename(die))) { - trace.source.function = name; - } - if ((name = die_call_file(die))) { - sloc.filename = name; - } - - Dwarf_Word line = 0, col = 0; - dwarf_formudata(dwarf_attr(die, DW_AT_call_line, - &attr_mem), &line); - dwarf_formudata(dwarf_attr(die, DW_AT_call_column, - &attr_mem), &col); - sloc.line = line; - sloc.col = col; - - trace.inliners.push_back(sloc); - break; - }; - } - ResolvedTrace& trace; - inliners_search_cb(ResolvedTrace& t): trace(t) {} - }; - - - static bool die_has_pc(Dwarf_Die* die, Dwarf_Addr pc) { - Dwarf_Addr low, high; - - // continuous range - if (dwarf_hasattr(die, DW_AT_low_pc) and - dwarf_hasattr(die, DW_AT_high_pc)) { - if (dwarf_lowpc(die, &low) != 0) { - return false; - } - if (dwarf_highpc(die, &high) != 0) { - Dwarf_Attribute attr_mem; - Dwarf_Attribute* attr = dwarf_attr(die, DW_AT_high_pc, &attr_mem); - Dwarf_Word value; - if (dwarf_formudata(attr, &value) != 0) { - return false; - } - high = low + value; - } - return pc >= low and pc < high; - } - - // non-continuous range. - Dwarf_Addr base; - ptrdiff_t offset = 0; - while ((offset = dwarf_ranges(die, offset, &base, &low, &high)) > 0) { - if (pc >= low and pc < high) { - return true; - } - } - return false; - } - - static Dwarf_Die* find_fundie_by_pc(Dwarf_Die* parent_die, Dwarf_Addr pc, - Dwarf_Die* result) { - if (dwarf_child(parent_die, result) != 0) { - return 0; - } - - Dwarf_Die* die = result; - do { - switch (dwarf_tag(die)) { - case DW_TAG_subprogram: - case DW_TAG_inlined_subroutine: - if (die_has_pc(die, pc)) { - return result; - } - default: - bool declaration = false; - Dwarf_Attribute attr_mem; - dwarf_formflag(dwarf_attr(die, DW_AT_declaration, - &attr_mem), &declaration); - if (not declaration) { - // let's be curious and look deeper in the tree, - // function are not necessarily at the first level, but - // might be nested inside a namespace, structure etc. - Dwarf_Die die_mem; - Dwarf_Die* indie = find_fundie_by_pc(die, pc, &die_mem); - if (indie) { - *result = die_mem; - return result; - } - } - }; - } while (dwarf_siblingof(die, result) == 0); - return 0; - } - - template - static bool deep_first_search_by_pc(Dwarf_Die* parent_die, - Dwarf_Addr pc, CB cb) { - Dwarf_Die die_mem; - if (dwarf_child(parent_die, &die_mem) != 0) { - return false; - } - - bool branch_has_pc = false; - Dwarf_Die* die = &die_mem; - do { - bool declaration = false; - Dwarf_Attribute attr_mem; - dwarf_formflag(dwarf_attr(die, DW_AT_declaration, &attr_mem), &declaration); - if (not declaration) { - // let's be curious and look deeper in the tree, function are - // not necessarily at the first level, but might be nested - // inside a namespace, structure, a function, an inlined - // function etc. - branch_has_pc = deep_first_search_by_pc(die, pc, cb); - } - if (not branch_has_pc) { - branch_has_pc = die_has_pc(die, pc); - } - if (branch_has_pc) { - cb(die); - } - } while (dwarf_siblingof(die, &die_mem) == 0); - return branch_has_pc; - } - - static const char* die_call_file(Dwarf_Die *die) { - Dwarf_Attribute attr_mem; - Dwarf_Sword file_idx = 0; - - dwarf_formsdata(dwarf_attr(die, DW_AT_call_file, &attr_mem), - &file_idx); - - if (file_idx == 0) { - return 0; - } - - Dwarf_Die die_mem; - Dwarf_Die* cudie = dwarf_diecu(die, &die_mem, 0, 0); - if (not cudie) { - return 0; - } - - Dwarf_Files* files = 0; - size_t nfiles; - dwarf_getsrcfiles(cudie, &files, &nfiles); - if (not files) { - return 0; - } - - return dwarf_filesrc(files, file_idx, 0, 0); - } - -}; -#endif // BACKWARD_HAS_DW == 1 - -template<> -class TraceResolverImpl: - public TraceResolverLinuxImpl {}; - -#endif // BACKWARD_SYSTEM_LINUX - -class TraceResolver: - public TraceResolverImpl {}; - -/*************** CODE SNIPPET ***************/ - -class SourceFile { -public: - typedef std::vector > lines_t; - - SourceFile() {} - SourceFile(const std::string& path): _file(new std::ifstream(path.c_str())) {} - bool is_open() const { return _file->is_open(); } - - lines_t& get_lines(unsigned line_start, unsigned line_count, lines_t& lines) { - using namespace std; - // This function make uses of the dumbest algo ever: - // 1) seek(0) - // 2) read lines one by one and discard until line_start - // 3) read line one by one until line_start + line_count - // - // If you are getting snippets many time from the same file, it is - // somewhat a waste of CPU, feel free to benchmark and propose a - // better solution ;) - - _file->clear(); - _file->seekg(0); - string line; - unsigned line_idx; - - for (line_idx = 1; line_idx < line_start; ++line_idx) { - getline(*_file, line); - if (not *_file) { - return lines; - } - } - - // think of it like a lambda in C++98 ;) - // but look, I will reuse it two times! - // What a good boy am I. - struct isspace { - bool operator()(char c) { - return std::isspace(c); - } - }; - - bool started = false; - for (; line_idx < line_start + line_count; ++line_idx) { - getline(*_file, line); - if (not *_file) { - return lines; - } - if (not started) { - if (std::find_if(line.begin(), line.end(), - not_isspace()) == line.end()) - continue; - started = true; - } - lines.push_back(make_pair(line_idx, line)); - } - - lines.erase( - std::find_if(lines.rbegin(), lines.rend(), - not_isempty()).base(), lines.end() - ); - return lines; - } - - lines_t get_lines(unsigned line_start, unsigned line_count) { - lines_t lines; - return get_lines(line_start, line_count, lines); - } - - // there is no find_if_not in C++98, lets do something crappy to - // workaround. - struct not_isspace { - bool operator()(char c) { - return not std::isspace(c); - } - }; - // and define this one here because C++98 is not happy with local defined - // struct passed to template functions, fuuuu. - struct not_isempty { - bool operator()(const lines_t::value_type& p) { - return not (std::find_if(p.second.begin(), p.second.end(), - not_isspace()) == p.second.end()); - } - }; - - void swap(SourceFile& b) { - _file.swap(b._file); - } - -#if defined(BACKWARD_CXX11) - SourceFile(SourceFile&& from): _file(0) { - swap(from); - } - SourceFile& operator=(SourceFile&& from) { - swap(from); return *this; - } -#else - explicit SourceFile(const SourceFile& from) { - // some sort of poor man's move semantic. - swap(const_cast(from)); - } - SourceFile& operator=(const SourceFile& from) { - // some sort of poor man's move semantic. - swap(const_cast(from)); return *this; - } -#endif - -private: - details::handle - > _file; - -#if defined(BACKWARD_CXX11) - SourceFile(const SourceFile&) = delete; - SourceFile& operator=(const SourceFile&) = delete; -#endif -}; - -class SnippetFactory { -public: - typedef SourceFile::lines_t lines_t; - - lines_t get_snippet(const std::string& filename, - unsigned line_start, unsigned context_size) { - - SourceFile& src_file = get_src_file(filename); - unsigned start = line_start - context_size / 2; - return src_file.get_lines(start, context_size); - } - - lines_t get_combined_snippet( - const std::string& filename_a, unsigned line_a, - const std::string& filename_b, unsigned line_b, - unsigned context_size) { - SourceFile& src_file_a = get_src_file(filename_a); - SourceFile& src_file_b = get_src_file(filename_b); - - lines_t lines = src_file_a.get_lines(line_a - context_size / 4, - context_size / 2); - src_file_b.get_lines(line_b - context_size / 4, context_size / 2, - lines); - return lines; - } - - lines_t get_coalesced_snippet(const std::string& filename, - unsigned line_a, unsigned line_b, unsigned context_size) { - SourceFile& src_file = get_src_file(filename); - - using std::min; using std::max; - unsigned a = min(line_a, line_b); - unsigned b = max(line_a, line_b); - - if ((b - a) < (context_size / 3)) { - return src_file.get_lines((a + b - context_size + 1) / 2, - context_size); - } - - lines_t lines = src_file.get_lines(a - context_size / 4, - context_size / 2); - src_file.get_lines(b - context_size / 4, context_size / 2, lines); - return lines; - } - - -private: - typedef details::hashtable::type src_files_t; - src_files_t _src_files; - - SourceFile& get_src_file(const std::string& filename) { - src_files_t::iterator it = _src_files.find(filename); - if (it != _src_files.end()) { - return it->second; - } - SourceFile& new_src_file = _src_files[filename]; - new_src_file = SourceFile(filename); - return new_src_file; - } -}; - -/*************** PRINTER ***************/ - -#ifdef BACKWARD_SYSTEM_LINUX - -namespace Color { - enum type { - yellow = 33, - purple = 35, - reset = 39 - }; -} // namespace Color - -class Colorize { -public: - Colorize(std::FILE* os): - _os(os), _reset(false), _istty(false) {} - - void init() { - _istty = isatty(fileno(_os)); - } - - void set_color(Color::type ccode) { - if (not _istty) return; - - // I assume that the terminal can handle basic colors. Seriously I - // don't want to deal with all the termcap shit. - fprintf(_os, "\033[%im", static_cast(ccode)); - _reset = (ccode != Color::reset); - } - - ~Colorize() { - if (_reset) { - set_color(Color::reset); - } - } - -private: - std::FILE* _os; - bool _reset; - bool _istty; -}; - -#else // ndef BACKWARD_SYSTEM_LINUX - - -namespace Color { - enum type { - yellow = 0, - purple = 0, - reset = 0 - }; -} // namespace Color - -class Colorize { -public: - Colorize(std::FILE*) {} - void init() {} - void set_color(Color::type) {} -}; - -#endif // BACKWARD_SYSTEM_LINUX - -class Printer { -public: - bool snippet; - bool color; - bool address; - bool object; - - Printer(): - snippet(true), - color(true), - address(false), - object(false) - {} - - template - FILE* print(StackTrace& st, FILE* os = stderr) { - using namespace std; - - Colorize colorize(os); - if (color) { - colorize.init(); - } - - fprintf(os, "Stack trace (most recent call last)"); - if (st.thread_id()) { - fprintf(os, " in thread %u:\n", st.thread_id()); - } else { - fprintf(os, ":\n"); - } - - _resolver.load_stacktrace(st); - for (unsigned trace_idx = st.size(); trace_idx > 0; --trace_idx) { - fprintf(os, "#%-2u", trace_idx); - bool already_indented = true; - const ResolvedTrace trace = _resolver.resolve(st[trace_idx-1]); - - if (not trace.source.filename.size() or object) { - fprintf(os, " Object \"%s\", at %p, in %s\n", - trace.object_filename.c_str(), trace.addr, - trace.object_function.c_str()); - already_indented = false; - } - - if (trace.source.filename.size()) { - for (size_t inliner_idx = trace.inliners.size(); - inliner_idx > 0; --inliner_idx) { - if (not already_indented) { - fprintf(os, " "); - } - const ResolvedTrace::SourceLoc& inliner_loc - = trace.inliners[inliner_idx-1]; - print_source_loc(os, " | ", inliner_loc); - if (snippet) { - print_snippet(os, " | ", inliner_loc, - colorize, Color::purple, 5); - } - already_indented = false; - } - - if (not already_indented) { - fprintf(os, " "); - } - print_source_loc(os, " ", trace.source, trace.addr); - if (snippet) { - print_snippet(os, " ", trace.source, - colorize, Color::yellow, 7); - } - - if (trace.locals.size()) { - print_locals(os, " ", trace.locals); - } - } - } - return os; - } -private: - TraceResolver _resolver; - SnippetFactory _snippets; - - void print_snippet(FILE* os, const char* indent, - const ResolvedTrace::SourceLoc& source_loc, - Colorize& colorize, Color::type color_code, - int context_size) - { - using namespace std; - typedef SnippetFactory::lines_t lines_t; - - lines_t lines = _snippets.get_snippet(source_loc.filename, - source_loc.line, context_size); - - for (lines_t::const_iterator it = lines.begin(); - it != lines.end(); ++it) { - if (it-> first == source_loc.line) { - colorize.set_color(color_code); - fprintf(os, "%s>", indent); - } else { - fprintf(os, "%s ", indent); - } - fprintf(os, "%4u: %s\n", it->first, it->second.c_str()); - if (it-> first == source_loc.line) { - colorize.set_color(Color::reset); - } - } - } - - void print_source_loc(FILE* os, const char* indent, - const ResolvedTrace::SourceLoc& source_loc, - void* addr=0) { - fprintf(os, "%sSource \"%s\", line %i, in %s", - indent, source_loc.filename.c_str(), (int)source_loc.line, - source_loc.function.c_str()); - - if (address and addr != 0) { - fprintf(os, " [%p]\n", addr); - } else { - fprintf(os, "\n"); - } - } - - void print_var(FILE* os, const char* base_indent, int indent, - const Variable& var) { - fprintf(os, "%s%s: ", base_indent, var.name.c_str()); - switch (var.kind) { - case Variable::VALUE: - fprintf(os, "%s\n", var.value().c_str()); - break; - case Variable::LIST: - fprintf(os, "["); - for (size_t i = 0; i < var.list().size(); ++i) { - if (i > 0) { - fprintf(os, ", %s", var.list()[i].c_str()); - } - fprintf(os, "%s", var.list()[i].c_str()); - } - fprintf(os, "]\n"); - break; - case Variable::MAP: - fprintf(os, "{\n"); - for (size_t i = 0; i < var.map().size(); ++i) { - if (i > 0) { - fprintf(os, ",\n%s", base_indent); - } - print_var(os, base_indent, indent + 2, var.map()[i]); - } - fprintf(os, "]\n"); - break; - }; - } - - void print_locals(FILE* os, const char* indent, - const std::vector& locals) { - fprintf(os, "%sLocal variables:\n", indent); - for (size_t i = 0; i < locals.size(); ++i) { - if (i > 0) { - fprintf(os, ",\n%s", indent); - } - print_var(os, indent, 0, locals[i]); - } - } -}; - -/*************** SIGNALS HANDLING ***************/ - -#ifdef BACKWARD_SYSTEM_LINUX - - -class SignalHandling { -public: - static std::vector make_default_signals() { - const int signals[] = { - // default action: Core - SIGILL, - SIGABRT, - SIGFPE, - SIGSEGV, - SIGBUS, - // I am not sure the following signals should be enabled by - // default: - // default action: Term - SIGHUP, - SIGINT, - SIGPIPE, - SIGALRM, - SIGTERM, - SIGUSR1, - SIGUSR2, - SIGPOLL, - SIGPROF, - SIGVTALRM, - SIGIO, - SIGPWR, - // default action: Core - SIGQUIT, - SIGSYS, - SIGTRAP, - SIGXCPU, - SIGXFSZ - }; - return std::vector(signals, signals + sizeof signals); - } - - SignalHandling(const std::vector& signals = make_default_signals()) : _loaded(false) { - bool success = true; - - const size_t stack_size = 1024 * 1024 * 8; - _stack_content.reset((char*)malloc(stack_size)); - if (_stack_content) { - stack_t ss; - ss.ss_sp = _stack_content.get(); - ss.ss_size = stack_size; - ss.ss_flags = 0; - if (sigaltstack(&ss, 0) < 0) { - success = false; - } - } else { - success = false; - } - - for (size_t i = 0; i < signals.size(); ++i) { - struct sigaction action; - action.sa_flags = SA_SIGINFO | SA_ONSTACK; - sigemptyset(&action.sa_mask); - action.sa_sigaction = &sig_handler; - - int r = sigaction(signals[i], &action, 0); - if (r < 0) success = false; - } - _loaded = success; - } - - bool loaded() const { return _loaded; } - -private: - details::handle _stack_content; - bool _loaded; - - static void sig_handler(int, siginfo_t* info, void* _ctx) { - ucontext_t *uctx = (ucontext_t*) _ctx; - - StackTrace st; - void* error_addr = 0; -#ifdef REG_RIP // x86_64 - error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_RIP]); -#elif defined(REG_EIP) // x86_32 - error_addr = reinterpret_cast(uctx->uc_mcontext.gregs[REG_EIP]); -#else -# warning ":/ sorry, ain't know no nothing none not of your architecture!" -#endif - if (error_addr) { - st.load_from(error_addr, 32); - } else { - st.load_here(32); - } - - Printer printer; - printer.address = true; - printer.print(st, stderr); - - psiginfo(info, 0); - // terminate the process immediately. - _exit(EXIT_FAILURE); - } -}; - -#endif // BACKWARD_SYSTEM_LINUX - -#ifdef BACKWARD_SYSTEM_UNKNOWN - -class SignalHandling { -public: - SignalHandling(const std::vector& = std::vector()) {} - bool init() { return false; } -}; - -#endif // BACKWARD_SYSTEM_UNKNOWN - -#if 0 -void crit_err_hdlr(int sig_num, siginfo_t * info, void * ucontext) -{ - void * array[50]; - void * caller_address; - char ** messages; - int size, i; - sig_ucontext_t * uc; - - uc = (sig_ucontext_t *)ucontext; - - /* Get the address at the time the signal was raised from the EIP (x86) */ - caller_address = (void *) uc->uc_mcontext.eip; - - fprintf(stderr, "signal %d (%s), address is %p from %p\n", - sig_num, strsignal(sig_num), info->si_addr, - (void *)caller_address); - - size = backtrace(array, 50); - - /* overwrite sigaction with caller's address */ - array[1] = caller_address; - - messages = backtrace_symbols(array, size); - - -void sig_handler(int sig, siginfo_t* info, void* _ctx) { -ucontext_t *context = (ucontext_t*) _ctx; - -psiginfo(info, "Shit hit the fan"); -exit(EXIT_FAILURE); -} - -using namespace std; - -void badass() { -cout << "baddass!" << endl; -((char*)&badass)[0] = 42; -} - -int main() { -struct sigaction action; -action.sa_flags = SA_SIGINFO; -sigemptyset(&action.sa_mask); -action.sa_sigaction = &sig_handler; -int r = sigaction(SIGSEGV, &action, 0); -if (r < 0) { err(errno, 0); } -r = sigaction(SIGILL, &action, 0); -if (r < 0) { err(errno, 0); } - -badass(); -return 0; -} - - -#endif - -// i want to get a stacktrace on: -// - abort -// - signals (segfault.. abort...) -// - exception -// - dont messup with gdb! -// - thread ID -// - helper for capturing stack trace inside exception -// propose a little magic wrapper to throw an exception adding a stacktrace, -// and propose a specific tool to get a stacktrace from an exception (if its -// available). -// - optional override __cxa_throw, then the specific magic tool could get -// the stacktrace. Might be possible to use a thread-local variable to do -// some shit. RTLD_DEEPBIND might do the tricks to override it on the fly. - -// maybe I can even get the last variables and theirs values? -// that might be possible. - -// print with code snippet -// print traceback demangled -// detect color stuff -// register all signals -// -// Seperate stacktrace (load and co function) -// than object extracting informations about a stack trace. - -// also public a simple function to print a stacktrace. - -// backtrace::StackTrace st; -// st.snapshot(); -// print(st); -// cout << st; - -} // namespace backward - -#endif /* H_GUARD */ diff --git a/src/dionysus/dionysus/chain.h b/src/dionysus/chain.h similarity index 100% rename from src/dionysus/dionysus/chain.h rename to src/dionysus/chain.h diff --git a/src/dionysus/dionysus/chain.hpp b/src/dionysus/chain.hpp similarity index 100% rename from src/dionysus/dionysus/chain.hpp rename to src/dionysus/chain.hpp diff --git a/src/dionysus/dionysus/clearing-reduction.h b/src/dionysus/clearing-reduction.h similarity index 100% rename from src/dionysus/dionysus/clearing-reduction.h rename to src/dionysus/clearing-reduction.h diff --git a/src/dionysus/dionysus/clearing-reduction.hpp b/src/dionysus/clearing-reduction.hpp similarity index 100% rename from src/dionysus/dionysus/clearing-reduction.hpp rename to src/dionysus/clearing-reduction.hpp diff --git a/src/dionysus/dionysus/cohomology-persistence.h b/src/dionysus/cohomology-persistence.h similarity index 100% rename from src/dionysus/dionysus/cohomology-persistence.h rename to src/dionysus/cohomology-persistence.h diff --git a/src/dionysus/dionysus/cohomology-persistence.hpp b/src/dionysus/cohomology-persistence.hpp similarity index 100% rename from src/dionysus/dionysus/cohomology-persistence.hpp rename to src/dionysus/cohomology-persistence.hpp diff --git a/src/dionysus/dionysus/diagram.h b/src/dionysus/diagram.h similarity index 100% rename from src/dionysus/dionysus/diagram.h rename to src/dionysus/diagram.h diff --git a/src/dionysus/dionysus/distances.h b/src/dionysus/distances.h similarity index 100% rename from src/dionysus/dionysus/distances.h rename to src/dionysus/distances.h diff --git a/src/dionysus/dionysus/distances.hpp b/src/dionysus/distances.hpp similarity index 100% rename from src/dionysus/dionysus/distances.hpp rename to src/dionysus/distances.hpp diff --git a/src/dionysus/dionysus/dlog/progress.h b/src/dionysus/dlog/progress.h similarity index 100% rename from src/dionysus/dionysus/dlog/progress.h rename to src/dionysus/dlog/progress.h diff --git a/src/dionysus/dionysus/fields/q.h b/src/dionysus/fields/q.h similarity index 100% rename from src/dionysus/dionysus/fields/q.h rename to src/dionysus/fields/q.h diff --git a/src/dionysus/dionysus/fields/z2.h b/src/dionysus/fields/z2.h similarity index 100% rename from src/dionysus/dionysus/fields/z2.h rename to src/dionysus/fields/z2.h diff --git a/src/dionysus/dionysus/fields/zp.h b/src/dionysus/fields/zp.h similarity index 100% rename from src/dionysus/dionysus/fields/zp.h rename to src/dionysus/fields/zp.h diff --git a/src/dionysus/dionysus/filtration.h b/src/dionysus/filtration.h similarity index 100% rename from src/dionysus/dionysus/filtration.h rename to src/dionysus/filtration.h diff --git a/src/dionysus/dionysus/omni-field-persistence.h b/src/dionysus/omni-field-persistence.h similarity index 100% rename from src/dionysus/dionysus/omni-field-persistence.h rename to src/dionysus/omni-field-persistence.h diff --git a/src/dionysus/dionysus/omni-field-persistence.hpp b/src/dionysus/omni-field-persistence.hpp similarity index 100% rename from src/dionysus/dionysus/omni-field-persistence.hpp rename to src/dionysus/omni-field-persistence.hpp diff --git a/src/dionysus/dionysus/ordinary-persistence.h b/src/dionysus/ordinary-persistence.h similarity index 100% rename from src/dionysus/dionysus/ordinary-persistence.h rename to src/dionysus/ordinary-persistence.h diff --git a/src/dionysus/dionysus/pair-recorder.h b/src/dionysus/pair-recorder.h similarity index 100% rename from src/dionysus/dionysus/pair-recorder.h rename to src/dionysus/pair-recorder.h diff --git a/src/dionysus/dionysus/reduced-matrix.h b/src/dionysus/reduced-matrix.h similarity index 100% rename from src/dionysus/dionysus/reduced-matrix.h rename to src/dionysus/reduced-matrix.h diff --git a/src/dionysus/dionysus/reduced-matrix.hpp b/src/dionysus/reduced-matrix.hpp similarity index 100% rename from src/dionysus/dionysus/reduced-matrix.hpp rename to src/dionysus/reduced-matrix.hpp diff --git a/src/dionysus/dionysus/reduction.h b/src/dionysus/reduction.h similarity index 100% rename from src/dionysus/dionysus/reduction.h rename to src/dionysus/reduction.h diff --git a/src/dionysus/dionysus/relative-homology-zigzag.h b/src/dionysus/relative-homology-zigzag.h similarity index 100% rename from src/dionysus/dionysus/relative-homology-zigzag.h rename to src/dionysus/relative-homology-zigzag.h diff --git a/src/dionysus/dionysus/relative-homology-zigzag.hpp b/src/dionysus/relative-homology-zigzag.hpp similarity index 100% rename from src/dionysus/dionysus/relative-homology-zigzag.hpp rename to src/dionysus/relative-homology-zigzag.hpp diff --git a/src/dionysus/dionysus/rips.h b/src/dionysus/rips.h similarity index 100% rename from src/dionysus/dionysus/rips.h rename to src/dionysus/rips.h diff --git a/src/dionysus/dionysus/rips.hpp b/src/dionysus/rips.hpp similarity index 100% rename from src/dionysus/dionysus/rips.hpp rename to src/dionysus/rips.hpp diff --git a/src/dionysus/dionysus/row-reduction.h b/src/dionysus/row-reduction.h similarity index 100% rename from src/dionysus/dionysus/row-reduction.h rename to src/dionysus/row-reduction.h diff --git a/src/dionysus/dionysus/row-reduction.hpp b/src/dionysus/row-reduction.hpp similarity index 100% rename from src/dionysus/dionysus/row-reduction.hpp rename to src/dionysus/row-reduction.hpp diff --git a/src/dionysus/dionysus/simplex.h b/src/dionysus/simplex.h similarity index 100% rename from src/dionysus/dionysus/simplex.h rename to src/dionysus/simplex.h diff --git a/src/dionysus/dionysus/sparse-row-matrix.h b/src/dionysus/sparse-row-matrix.h similarity index 100% rename from src/dionysus/dionysus/sparse-row-matrix.h rename to src/dionysus/sparse-row-matrix.h diff --git a/src/dionysus/dionysus/sparse-row-matrix.hpp b/src/dionysus/sparse-row-matrix.hpp similarity index 100% rename from src/dionysus/dionysus/sparse-row-matrix.hpp rename to src/dionysus/sparse-row-matrix.hpp diff --git a/src/dionysus/dionysus/standard-reduction.h b/src/dionysus/standard-reduction.h similarity index 100% rename from src/dionysus/dionysus/standard-reduction.h rename to src/dionysus/standard-reduction.h diff --git a/src/dionysus/dionysus/standard-reduction.hpp b/src/dionysus/standard-reduction.hpp similarity index 100% rename from src/dionysus/dionysus/standard-reduction.hpp rename to src/dionysus/standard-reduction.hpp diff --git a/src/dionysus/dionysus/trails-chains.h b/src/dionysus/trails-chains.h similarity index 100% rename from src/dionysus/dionysus/trails-chains.h rename to src/dionysus/trails-chains.h diff --git a/src/dionysus/dionysus/zigzag-persistence.h b/src/dionysus/zigzag-persistence.h similarity index 100% rename from src/dionysus/dionysus/zigzag-persistence.h rename to src/dionysus/zigzag-persistence.h diff --git a/src/dionysus/dionysus/zigzag-persistence.hpp b/src/dionysus/zigzag-persistence.hpp similarity index 100% rename from src/dionysus/dionysus/zigzag-persistence.hpp rename to src/dionysus/zigzag-persistence.hpp diff --git a/src/tdautils/dionysusUtils.h b/src/tdautils/dionysusUtils.h index bc74f56..925b85e 100644 --- a/src/tdautils/dionysusUtils.h +++ b/src/tdautils/dionysusUtils.h @@ -1,10 +1,9 @@ #ifndef __DIONYSUSUTILS_H__ #define __DIONYSUSUTILS_H__ -#include "../dionysus/dionysus/simplex.h" -#include "../dionysus/dionysus/rips.h" -#include "../dionysus/dionysus/filtration.h" -#include "../dionysus/dionysus/standard-reduction.h" +#include +#include +#include // swapping simplex //#include From 6fd8b2ceb2215cf47b73daf8671d4b5d70ee9c5e Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 25 Jul 2018 11:01:25 -0600 Subject: [PATCH 15/29] swapped distances.h --- src/tdautils/gridUtils.h | 17 +++++++++++------ src/tdautils/ripsL2.h | 33 +++++++++++++++++++++++++++------ src/tdautils/ripsL2backup.h | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 src/tdautils/ripsL2backup.h diff --git a/src/tdautils/gridUtils.h b/src/tdautils/gridUtils.h index 7d12b09..1a2219c 100644 --- a/src/tdautils/gridUtils.h +++ b/src/tdautils/gridUtils.h @@ -3,13 +3,18 @@ #include -#include -#include -#include -#include -#include +//#include +//#include +//#include +//#include +//#include #include +#include +#include +#include +#include + #include #include #include @@ -531,4 +536,4 @@ void simplicesFromGridBarycenter( -# endif // __GRIDUTILS_H__ \ No newline at end of file +# endif // __GRIDUTILS_H__ diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index 2acc8ca..d3d044b 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -1,24 +1,45 @@ #include -#include +//#include #include #include #include -#include -#include +//#include +//#include #include // for BackInsertFunctor #include + +//dionysus2 +//#include +#include +//#include +#include +//#include #include -typedef PairwiseDistances PairDistances; +namespace d = dionysus; + + +//L2 Struct is inside distances.h +// Feels very janky + +// typedef std::vector Point; +// typedef std::vector PointContainer; + +typedef d::PairwiseDistances>, d::L2Distance>> PairDistances; typedef PairDistances::DistanceType DistanceType; typedef PairDistances::IndexType VertexR; typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; typedef Generator::Simplex SmplxR; typedef Filtration FltrR; + +// Comment this test typedef StaticPersistence<> PersistenceR; -//typedef DynamicPersistenceChains<> PersistenceR; -typedef PersistenceDiagram<> PDgmR; +// relabel +//typedef OrdinaryPersistence<> PersistenceR; +//typedef DynamicPersistenceChains<> PersistenceR; +typedef PersistenceDiagram<> PDgmR; + diff --git a/src/tdautils/ripsL2backup.h b/src/tdautils/ripsL2backup.h new file mode 100644 index 0000000..172a900 --- /dev/null +++ b/src/tdautils/ripsL2backup.h @@ -0,0 +1,32 @@ +#include +//#include +#include +#include +#include + +#include +#include +#include // for BackInsertFunctor +#include + +//dionysus2 +//#include +#include +//#include +//#include +//#include + +#include + + +typedef PairwiseDistances PairDistances; +typedef PairDistances::DistanceType DistanceType; +typedef PairDistances::IndexType VertexR; +typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; +typedef Generator::Simplex SmplxR; +typedef Filtration FltrR; +typedef StaticPersistence<> PersistenceR; +//typedef DynamicPersistenceChains<> PersistenceR; +typedef PersistenceDiagram<> PDgmR; + + From b5b3705a94d8c942eff2421400c7ee5e1b080c32 Mon Sep 17 00:00:00 2001 From: thomashli Date: Mon, 30 Jul 2018 16:35:08 -0600 Subject: [PATCH 16/29] Arbit works to --- src/tdautils/ripsArbit.h | 13 +++++++++---- src/tdautils/ripsL2.h | 7 ++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/tdautils/ripsArbit.h b/src/tdautils/ripsArbit.h index 9e87f52..47d0ec5 100644 --- a/src/tdautils/ripsArbit.h +++ b/src/tdautils/ripsArbit.h @@ -1,19 +1,24 @@ #include -#include +//#include #include #include #include -#include +//#include #include -#include +//#include #include // for BackInsertFunctor #include +//dionysus2 +#include +#include + #include +namespace d = dionysus; -typedef PairwiseDistances PairDistancesA; +typedef d::PairwiseDistances>, ArbitDistance> PairDistancesA; typedef PairDistancesA::DistanceType DistanceTypeA; typedef PairDistancesA::IndexType VertexRA; typedef Rips< PairDistancesA, Simplex< VertexRA, double > > GeneratorA; diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index d3d044b..53ac103 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -22,14 +22,19 @@ namespace d = dionysus; //L2 Struct is inside distances.h -// Feels very janky +// not sure if I want to typedef here // typedef std::vector Point; // typedef std::vector PointContainer; +// Swapped out typedef d::PairwiseDistances>, d::L2Distance>> PairDistances; + +//Next two lines are fine typedef PairDistances::DistanceType DistanceType; typedef PairDistances::IndexType VertexR; + + typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; typedef Generator::Simplex SmplxR; typedef Filtration FltrR; From 798e7a9630542a18b2061709d5bb759c934e0fd6 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:33:56 -0700 Subject: [PATCH 17/29] Made FiltrationDiagDionysus2 --- src/tdautils/diagramDS.h | 61 +++++++++++++ src/tdautils/dionysus2Utils.h | 129 +++++++++++++++++++++++++++ tests/testthat/test_FiltrationDiag.R | 15 ++++ 3 files changed, 205 insertions(+) create mode 100644 src/tdautils/diagramDS.h create mode 100644 src/tdautils/dionysus2Utils.h create mode 100644 tests/testthat/test_FiltrationDiag.R diff --git a/src/tdautils/diagramDS.h b/src/tdautils/diagramDS.h new file mode 100644 index 0000000..84333b1 --- /dev/null +++ b/src/tdautils/diagramDS.h @@ -0,0 +1,61 @@ +#ifndef DIAGRAM_DS_H +#define DIAGRAM_DS_H +#endif + +#include +#include + +namespace d = dionysus; + +namespace diagramDS +{ + +template +class DiagramDS +{ + public: + using Value = Value_; + using Data = Data_; + using Diagrams = std::vector>; + + template + DiagramDS(const ReducedMatrix& m, const Filtration& f, const GetValue get_value, const GetData get_data) + { + //auto get_value = [&](const Simplex& s) -> float { return filtration.index(s); }; + //auto get_data = [](Persistence::Index i) { return i; }; + for (typename ReducedMatrix::Index i = 0; i < m.size(); ++i) + { + if (m.skip(i)) + continue; + + auto& s = f[i]; + auto d = s.dimension(); + while (d + 1 > diagrams.size()) + diagrams.emplace_back(); + + auto pair = m.pair(i); + if (pair == m.unpaired()) + { + auto birth = get_value(s); + using Value = decltype(birth); + Value death = std::numeric_limits::infinity(); + diagrams[d].emplace_back(birth, death, get_data(i)); + } else if (pair > i) // positive + { + auto birth = get_value(s); + auto death = get_value(f[pair]); + + if (birth != death) // skip diagonal + diagrams[d].emplace_back(birth, death, get_data(i)); + } // else negative: do nothing + } + } + + Diagrams getDiagrams() {return diagrams;} + private: + Diagrams diagrams; +}; + +} + + diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h new file mode 100644 index 0000000..1670e1a --- /dev/null +++ b/src/tdautils/dionysus2Utils.h @@ -0,0 +1,129 @@ +#ifndef __DIONYSUS2UTILS_H__ +#define __DIONYSUS2UTILS_H__ + +#include +#include +#include +#include +#include +#include +#include +#include "diagramDS.h" +#include + +namespace d = dionysus; + + + +/* + * I'm going to assume that we simply pass in the persDgm, persLoc, and persCycle as they are above. + * I should ask Dave tomorrow what persDgm, persLoc, and persCycle are. + * Persistence Locations: for each diagram, there is a collection of points, and the points are "steps" + * Perstistence Cycle: Birth and Death? + * + * Make example to test it + * + * ask how to get the vignette + * Figure out what the Filtration that gets passed in is. + * Want to pass in Ordinary-Persistence with a Z2Field first go + * + * ask dave what initLocations and initDiagrams do + * + * ask dave what format the cmplx is passed in as + * Figure out what format Ordinary Persistence saves the persistence as + * + * Figure out what format TDA filtration is in typecast utils + * What is the format for TDA filtration? + * + * Figure out what is in FiltrationDiagDionysus + * Parameter probably comes from gridUtils.h defining Persistence as static-persistence + * Probably want to rename persistence2 + * Figure out how to convert Filtration.h in D1 to D2 so the rest of the methods work. + * + */ + +//Helper function for filling in persDgm +//template< typename Diagrams, typename iterator, typename Evaluator, typename SimplexMap > +//inline void initDiagrams; +//Helper function for filling in persLoc and persCycle +// inline void initLocations; + +// FiltrationDiag in Dionysus2 +/** \brief Construct the persistence diagram from the filtration using library +* Dionysus. +* +* @param[out] void Void +* @param[in] filtration The input filtration +* @param[in] maxdimension Max dimension of the homological features to be +* computed +* @param[in] location Are location of birth point, death point, and +* representative cycles returned? +* @param[in] printProgress Is progress printed? +* @param[in] persDgm Memory space for the resulting persistence +* diagram +* @param[in] persLoc Memory space for the resulting birth points and +* death points +* @param[in] persCycle Memory space for the resulting representative +* cycles +* @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the +* diagram. Diagram must point to enough memory space for +* 3*max_num_pairs double. If there is not enough pairs in the diagram, +* write nothing after. +*/ + +template +void FiltrationDiagDionysus2( + const Filtration &filtration, + const int maxdimension, + const bool location, + const bool printProgress, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + //Assume that Persistence that is passed in is Persistence2 + //Calculate Persistence + + d::Z2Field k; + //Persistence persistence(k); + //StandardReduction2 reduce(persistence); + d::RowReduction reduce(k); + // We know that the function breaks when this line is called. + reduce(filtration); + + typedef decltype(reduce.persistence().pair(0)) Index; + typedef float Value; + //persistence is reduced. + Index _ = 0; + // move Persistence into persDgm + //auto dgms = d::init_diagrams(reduce->persistence(), filtration, [&](const Smplx2& s) -> float { return filtration.index(s); }, [](typename Persistence::Index i) { return i; }); + + diagramDS::DiagramDS dgms( + reduce.persistence(), + filtration, + [&](const Smplx2& s) -> float { return filtration.index(s);}, + [](typename Persistence::Index i) { return i; } + ); + //emulate initDiagrams function from dionysusUtils + //will put into a function later + persDgm.resize(dgms.getDiagrams().size()); + for (auto &dgm : dgms.getDiagrams()) + { + for (auto &pt : dgm) + { + std::vector pt_; + if (pt.death() == std::numeric_limits::infinity()) { + pt_ = {filtration[pt.birth()].data(),pt.death()}; + } else { + pt_ = {filtration[pt.birth()].data(),filtration[pt.death()].data()}; + } + persDgm[_].push_back(pt_); + } + _++; + } + persDgm.resize(maxdimension + 1); + +} + +#endif __DIONYSUS2UTILS_H__ diff --git a/tests/testthat/test_FiltrationDiag.R b/tests/testthat/test_FiltrationDiag.R new file mode 100644 index 0000000..36acb7f --- /dev/null +++ b/tests/testthat/test_FiltrationDiag.R @@ -0,0 +1,15 @@ +context("FiltrationDiag") + +test_that("FiltrationDiag works with dionysus2" , { + X <- matrix(c(0,0,100,100,0,102,0,101),nrow=4) + Fltrips = ripsFiltration(X,maxdimension = 1, maxscale = 120, library = "Dionysus") + DiagRips = filtrationDiag(Fltrips, maxdimension = 0, library = "Dionysus") + DiagRips2 = filtrationDiag(Fltrips, maxdimension = 0, library = "D2", location = FALSE) + for (i in 1:nrow(DiagRips)) { + for (j in 1:ncol(DiagRips)) { + expect_equal(DiagRips[i,j],DiagRips2[i,j]) + } + } +}) + + From e91ed9fb0971dfd6627a7c787079a6966e090188 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:42:52 -0700 Subject: [PATCH 18/29] prototype FiltrationDiag working --- src/tdautils/dionysus2Utils.h | 61 ++++++----------------------------- 1 file changed, 10 insertions(+), 51 deletions(-) diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h index 1670e1a..674c4ad 100644 --- a/src/tdautils/dionysus2Utils.h +++ b/src/tdautils/dionysus2Utils.h @@ -13,41 +13,6 @@ namespace d = dionysus; - - -/* - * I'm going to assume that we simply pass in the persDgm, persLoc, and persCycle as they are above. - * I should ask Dave tomorrow what persDgm, persLoc, and persCycle are. - * Persistence Locations: for each diagram, there is a collection of points, and the points are "steps" - * Perstistence Cycle: Birth and Death? - * - * Make example to test it - * - * ask how to get the vignette - * Figure out what the Filtration that gets passed in is. - * Want to pass in Ordinary-Persistence with a Z2Field first go - * - * ask dave what initLocations and initDiagrams do - * - * ask dave what format the cmplx is passed in as - * Figure out what format Ordinary Persistence saves the persistence as - * - * Figure out what format TDA filtration is in typecast utils - * What is the format for TDA filtration? - * - * Figure out what is in FiltrationDiagDionysus - * Parameter probably comes from gridUtils.h defining Persistence as static-persistence - * Probably want to rename persistence2 - * Figure out how to convert Filtration.h in D1 to D2 so the rest of the methods work. - * - */ - -//Helper function for filling in persDgm -//template< typename Diagrams, typename iterator, typename Evaluator, typename SimplexMap > -//inline void initDiagrams; -//Helper function for filling in persLoc and persCycle -// inline void initLocations; - // FiltrationDiag in Dionysus2 /** \brief Construct the persistence diagram from the filtration using library * Dionysus. @@ -82,32 +47,24 @@ void FiltrationDiagDionysus2( std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle ) { - //Assume that Persistence that is passed in is Persistence2 - //Calculate Persistence - + // Assume that Persistence that is passed in is Persistence2, Filtration is Fltr2 + // Create and Calculate Persistence d::Z2Field k; - //Persistence persistence(k); - //StandardReduction2 reduce(persistence); d::RowReduction reduce(k); - // We know that the function breaks when this line is called. reduce(filtration); typedef decltype(reduce.persistence().pair(0)) Index; typedef float Value; - //persistence is reduced. - Index _ = 0; - // move Persistence into persDgm - //auto dgms = d::init_diagrams(reduce->persistence(), filtration, [&](const Smplx2& s) -> float { return filtration.index(s); }, [](typename Persistence::Index i) { return i; }); - - diagramDS::DiagramDS dgms( + // Putting persistence into diagrams data structure + diagramDS::DiagramDS dgms( reduce.persistence(), filtration, [&](const Smplx2& s) -> float { return filtration.index(s);}, [](typename Persistence::Index i) { return i; } ); - //emulate initDiagrams function from dionysusUtils - //will put into a function later + // Fill in Diagram persDgm.resize(dgms.getDiagrams().size()); + Index _ = 0; for (auto &dgm : dgms.getDiagrams()) { for (auto &pt : dgm) @@ -122,8 +79,10 @@ void FiltrationDiagDionysus2( } _++; } - persDgm.resize(maxdimension + 1); - + // Capping at maxdimension + if (persDgm.size() > maxdimension) { + persDgm.resize(maxdimension + 1); + } } #endif __DIONYSUS2UTILS_H__ From 4ff5a0874b2684dc6cafff5d2ae6595bf8292767 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:44:09 -0700 Subject: [PATCH 19/29] checkout master? --- src/tdautils/gridUtils.h | 13 +++++++++++-- src/tdautils/ripsL2.h | 27 +++++++++++---------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/tdautils/gridUtils.h b/src/tdautils/gridUtils.h index 1a2219c..65bb5b7 100644 --- a/src/tdautils/gridUtils.h +++ b/src/tdautils/gridUtils.h @@ -12,7 +12,9 @@ #include #include -#include +//#include +#include +#include #include #include @@ -42,7 +44,14 @@ typedef OffsetBeginMap FiltrationPersistenceMap; - +//dionysus2 +//needs changing +typedef d::Simplex Smplx2; +typedef d::Filtration Fltr2; +typedef d::Simplex<> Simplex2; +typedef d::Filtration Filtration2; +typedef d::ReducedMatrix Persistence2; +typedef d::StandardReduction StandardReduction2; // add a single edge to the filtration template< typename VectorList > diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index 53ac103..81e2d44 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -10,10 +10,12 @@ #include //dionysus2 +#include +//#include +//#include //#include -#include //#include -#include +//#include //#include #include @@ -27,24 +29,17 @@ namespace d = dionysus; // typedef std::vector Point; // typedef std::vector PointContainer; -// Swapped out typedef d::PairwiseDistances>, d::L2Distance>> PairDistances; - -//Next two lines are fine typedef PairDistances::DistanceType DistanceType; typedef PairDistances::IndexType VertexR; - - typedef Rips< PairDistances, Simplex< VertexR, double > > Generator; typedef Generator::Simplex SmplxR; -typedef Filtration FltrR; - -// Comment this test -typedef StaticPersistence<> PersistenceR; - -// relabel -//typedef OrdinaryPersistence<> PersistenceR; - +typedef Filtration FltrR; +//typedef StaticPersistence<> PersistenceR; +//typedef d::Simplex<> Simplex2; +//typedef d::Filtration Filtration2; +//typedef d::OrdinaryPersistence Persistence2; +//typedef d::StandardReduction StandardReduction2; //typedef DynamicPersistenceChains<> PersistenceR; -typedef PersistenceDiagram<> PDgmR; +//typedef PersistenceDiagram<> PDgmR; From a5e93675b80a03929d2e764c56b08b6f586b3578 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:48:44 -0700 Subject: [PATCH 20/29] preparing to rebase --- R/RcppExports.R | 67 +++++++++++ src/RcppExports.cpp | 264 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 R/RcppExports.R create mode 100644 src/RcppExports.cpp diff --git a/R/RcppExports.R b/R/RcppExports.R new file mode 100644 index 0000000..ace733d --- /dev/null +++ b/R/RcppExports.R @@ -0,0 +1,67 @@ +# Generated by using Rcpp::compileAttributes() -> do not edit by hand +# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 + +GridFiltration <- function(FUNvalues, gridDim, maxdimension, decomposition, printProgress) { + .Call(`_TDA_GridFiltration`, FUNvalues, gridDim, maxdimension, decomposition, printProgress) +} + +GridDiag <- function(FUNvalues, gridDim, maxdimension, decomposition, library, location, printProgress) { + .Call(`_TDA_GridDiag`, FUNvalues, gridDim, maxdimension, decomposition, library, location, printProgress) +} + +Bottleneck <- function(Diag1, Diag2) { + .Call(`_TDA_Bottleneck`, Diag1, Diag2) +} + +Wasserstein <- function(Diag1, Diag2, p) { + .Call(`_TDA_Wasserstein`, Diag1, Diag2, p) +} + +Kde <- function(X, Grid, h, kertype, weight, printProgress) { + .Call(`_TDA_Kde`, X, Grid, h, kertype, weight, printProgress) +} + +KdeDist <- function(X, Grid, h, weight, printProgress) { + .Call(`_TDA_KdeDist`, X, Grid, h, weight, printProgress) +} + +Dtm <- function(knnDistance, weightBound, r) { + .Call(`_TDA_Dtm`, knnDistance, weightBound, r) +} + +DtmWeight <- function(knnDistance, weightBound, r, knnIndex, weight) { + .Call(`_TDA_DtmWeight`, knnDistance, weightBound, r, knnIndex, weight) +} + +FiltrationDiag <- function(filtration, maxdimension, library, location, printProgress) { + .Call(`_TDA_FiltrationDiag`, filtration, maxdimension, library, location, printProgress) +} + +FunFiltration <- function(FUNvalues, cmplx) { + .Call(`_TDA_FunFiltration`, FUNvalues, cmplx) +} + +RipsFiltration <- function(X, maxdimension, maxscale, dist, library, printProgress) { + .Call(`_TDA_RipsFiltration`, X, maxdimension, maxscale, dist, library, printProgress) +} + +RipsDiag <- function(X, maxdimension, maxscale, dist, libraryFiltration, libraryDiag, location, printProgress) { + .Call(`_TDA_RipsDiag`, X, maxdimension, maxscale, dist, libraryFiltration, libraryDiag, location, printProgress) +} + +AlphaShapeFiltration <- function(X, printProgress) { + .Call(`_TDA_AlphaShapeFiltration`, X, printProgress) +} + +AlphaShapeDiag <- function(X, maxdimension, libraryDiag, location, printProgress) { + .Call(`_TDA_AlphaShapeDiag`, X, maxdimension, libraryDiag, location, printProgress) +} + +AlphaComplexFiltration <- function(X, printProgress) { + .Call(`_TDA_AlphaComplexFiltration`, X, printProgress) +} + +AlphaComplexDiag <- function(X, maxdimension, libraryDiag, location, printProgress) { + .Call(`_TDA_AlphaComplexDiag`, X, maxdimension, libraryDiag, location, printProgress) +} + diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp new file mode 100644 index 0000000..f772558 --- /dev/null +++ b/src/RcppExports.cpp @@ -0,0 +1,264 @@ +// Generated by using Rcpp::compileAttributes() -> do not edit by hand +// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 + +#include +#include + +using namespace Rcpp; + +// GridFiltration +Rcpp::List GridFiltration(const Rcpp::NumericVector& FUNvalues, const Rcpp::IntegerVector& gridDim, const int maxdimension, const std::string& decomposition, const bool printProgress); +RcppExport SEXP _TDA_GridFiltration(SEXP FUNvaluesSEXP, SEXP gridDimSEXP, SEXP maxdimensionSEXP, SEXP decompositionSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type FUNvalues(FUNvaluesSEXP); + Rcpp::traits::input_parameter< const Rcpp::IntegerVector& >::type gridDim(gridDimSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type decomposition(decompositionSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(GridFiltration(FUNvalues, gridDim, maxdimension, decomposition, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// GridDiag +Rcpp::List GridDiag(const Rcpp::NumericVector& FUNvalues, const Rcpp::IntegerVector& gridDim, const int maxdimension, const std::string& decomposition, const std::string& library, const bool location, const bool printProgress); +RcppExport SEXP _TDA_GridDiag(SEXP FUNvaluesSEXP, SEXP gridDimSEXP, SEXP maxdimensionSEXP, SEXP decompositionSEXP, SEXP librarySEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type FUNvalues(FUNvaluesSEXP); + Rcpp::traits::input_parameter< const Rcpp::IntegerVector& >::type gridDim(gridDimSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type decomposition(decompositionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type library(librarySEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(GridDiag(FUNvalues, gridDim, maxdimension, decomposition, library, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// Bottleneck +double Bottleneck(const Rcpp::NumericMatrix& Diag1, const Rcpp::NumericMatrix& Diag2); +RcppExport SEXP _TDA_Bottleneck(SEXP Diag1SEXP, SEXP Diag2SEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag1(Diag1SEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag2(Diag2SEXP); + rcpp_result_gen = Rcpp::wrap(Bottleneck(Diag1, Diag2)); + return rcpp_result_gen; +END_RCPP +} +// Wasserstein +double Wasserstein(const Rcpp::NumericMatrix& Diag1, const Rcpp::NumericMatrix& Diag2, const int p); +RcppExport SEXP _TDA_Wasserstein(SEXP Diag1SEXP, SEXP Diag2SEXP, SEXP pSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag1(Diag1SEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Diag2(Diag2SEXP); + Rcpp::traits::input_parameter< const int >::type p(pSEXP); + rcpp_result_gen = Rcpp::wrap(Wasserstein(Diag1, Diag2, p)); + return rcpp_result_gen; +END_RCPP +} +// Kde +Rcpp::NumericVector Kde(const Rcpp::NumericMatrix& X, const Rcpp::NumericMatrix& Grid, const double h, const std::string& kertype, const Rcpp::NumericVector& weight, const bool printProgress); +RcppExport SEXP _TDA_Kde(SEXP XSEXP, SEXP GridSEXP, SEXP hSEXP, SEXP kertypeSEXP, SEXP weightSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Grid(GridSEXP); + Rcpp::traits::input_parameter< const double >::type h(hSEXP); + Rcpp::traits::input_parameter< const std::string& >::type kertype(kertypeSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type weight(weightSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(Kde(X, Grid, h, kertype, weight, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// KdeDist +Rcpp::NumericVector KdeDist(const Rcpp::NumericMatrix& X, const Rcpp::NumericMatrix& Grid, const double h, const Rcpp::NumericVector& weight, const bool printProgress); +RcppExport SEXP _TDA_KdeDist(SEXP XSEXP, SEXP GridSEXP, SEXP hSEXP, SEXP weightSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type Grid(GridSEXP); + Rcpp::traits::input_parameter< const double >::type h(hSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type weight(weightSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(KdeDist(X, Grid, h, weight, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// Dtm +Rcpp::NumericVector Dtm(const Rcpp::NumericMatrix& knnDistance, const double weightBound, const double r); +RcppExport SEXP _TDA_Dtm(SEXP knnDistanceSEXP, SEXP weightBoundSEXP, SEXP rSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type knnDistance(knnDistanceSEXP); + Rcpp::traits::input_parameter< const double >::type weightBound(weightBoundSEXP); + Rcpp::traits::input_parameter< const double >::type r(rSEXP); + rcpp_result_gen = Rcpp::wrap(Dtm(knnDistance, weightBound, r)); + return rcpp_result_gen; +END_RCPP +} +// DtmWeight +Rcpp::NumericVector DtmWeight(const Rcpp::NumericMatrix& knnDistance, const double weightBound, const double r, const Rcpp::NumericMatrix& knnIndex, const Rcpp::NumericVector& weight); +RcppExport SEXP _TDA_DtmWeight(SEXP knnDistanceSEXP, SEXP weightBoundSEXP, SEXP rSEXP, SEXP knnIndexSEXP, SEXP weightSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type knnDistance(knnDistanceSEXP); + Rcpp::traits::input_parameter< const double >::type weightBound(weightBoundSEXP); + Rcpp::traits::input_parameter< const double >::type r(rSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type knnIndex(knnIndexSEXP); + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type weight(weightSEXP); + rcpp_result_gen = Rcpp::wrap(DtmWeight(knnDistance, weightBound, r, knnIndex, weight)); + return rcpp_result_gen; +END_RCPP +} +// FiltrationDiag +Rcpp::List FiltrationDiag(const Rcpp::List& filtration, const int maxdimension, const std::string& library, const bool location, const bool printProgress); +RcppExport SEXP _TDA_FiltrationDiag(SEXP filtrationSEXP, SEXP maxdimensionSEXP, SEXP librarySEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type filtration(filtrationSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type library(librarySEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(FiltrationDiag(filtration, maxdimension, library, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// FunFiltration +Rcpp::List FunFiltration(const Rcpp::NumericVector& FUNvalues, const Rcpp::List& cmplx); +RcppExport SEXP _TDA_FunFiltration(SEXP FUNvaluesSEXP, SEXP cmplxSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericVector& >::type FUNvalues(FUNvaluesSEXP); + Rcpp::traits::input_parameter< const Rcpp::List& >::type cmplx(cmplxSEXP); + rcpp_result_gen = Rcpp::wrap(FunFiltration(FUNvalues, cmplx)); + return rcpp_result_gen; +END_RCPP +} +// RipsFiltration +Rcpp::List RipsFiltration(const Rcpp::NumericMatrix& X, const int maxdimension, const double maxscale, const std::string& dist, const std::string& library, const bool printProgress); +RcppExport SEXP _TDA_RipsFiltration(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP maxscaleSEXP, SEXP distSEXP, SEXP librarySEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const double >::type maxscale(maxscaleSEXP); + Rcpp::traits::input_parameter< const std::string& >::type dist(distSEXP); + Rcpp::traits::input_parameter< const std::string& >::type library(librarySEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(RipsFiltration(X, maxdimension, maxscale, dist, library, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// RipsDiag +Rcpp::List RipsDiag(const Rcpp::NumericMatrix& X, const int maxdimension, const double maxscale, const std::string& dist, const std::string& libraryFiltration, const std::string& libraryDiag, const bool location, const bool printProgress); +RcppExport SEXP _TDA_RipsDiag(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP maxscaleSEXP, SEXP distSEXP, SEXP libraryFiltrationSEXP, SEXP libraryDiagSEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const double >::type maxscale(maxscaleSEXP); + Rcpp::traits::input_parameter< const std::string& >::type dist(distSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryFiltration(libraryFiltrationSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryDiag(libraryDiagSEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(RipsDiag(X, maxdimension, maxscale, dist, libraryFiltration, libraryDiag, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaShapeFiltration +Rcpp::List AlphaShapeFiltration(const Rcpp::NumericMatrix& X, const bool printProgress); +RcppExport SEXP _TDA_AlphaShapeFiltration(SEXP XSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaShapeFiltration(X, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaShapeDiag +Rcpp::List AlphaShapeDiag(const Rcpp::NumericMatrix& X, const int maxdimension, const std::string& libraryDiag, const bool location, const bool printProgress); +RcppExport SEXP _TDA_AlphaShapeDiag(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP libraryDiagSEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryDiag(libraryDiagSEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaShapeDiag(X, maxdimension, libraryDiag, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaComplexFiltration +Rcpp::List AlphaComplexFiltration(const Rcpp::NumericMatrix& X, const bool printProgress); +RcppExport SEXP _TDA_AlphaComplexFiltration(SEXP XSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaComplexFiltration(X, printProgress)); + return rcpp_result_gen; +END_RCPP +} +// AlphaComplexDiag +Rcpp::List AlphaComplexDiag(const Rcpp::NumericMatrix& X, const int maxdimension, const std::string& libraryDiag, const bool location, const bool printProgress); +RcppExport SEXP _TDA_AlphaComplexDiag(SEXP XSEXP, SEXP maxdimensionSEXP, SEXP libraryDiagSEXP, SEXP locationSEXP, SEXP printProgressSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::NumericMatrix& >::type X(XSEXP); + Rcpp::traits::input_parameter< const int >::type maxdimension(maxdimensionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type libraryDiag(libraryDiagSEXP); + Rcpp::traits::input_parameter< const bool >::type location(locationSEXP); + Rcpp::traits::input_parameter< const bool >::type printProgress(printProgressSEXP); + rcpp_result_gen = Rcpp::wrap(AlphaComplexDiag(X, maxdimension, libraryDiag, location, printProgress)); + return rcpp_result_gen; +END_RCPP +} + +static const R_CallMethodDef CallEntries[] = { + {"_TDA_GridFiltration", (DL_FUNC) &_TDA_GridFiltration, 5}, + {"_TDA_GridDiag", (DL_FUNC) &_TDA_GridDiag, 7}, + {"_TDA_Bottleneck", (DL_FUNC) &_TDA_Bottleneck, 2}, + {"_TDA_Wasserstein", (DL_FUNC) &_TDA_Wasserstein, 3}, + {"_TDA_Kde", (DL_FUNC) &_TDA_Kde, 6}, + {"_TDA_KdeDist", (DL_FUNC) &_TDA_KdeDist, 5}, + {"_TDA_Dtm", (DL_FUNC) &_TDA_Dtm, 3}, + {"_TDA_DtmWeight", (DL_FUNC) &_TDA_DtmWeight, 5}, + {"_TDA_FiltrationDiag", (DL_FUNC) &_TDA_FiltrationDiag, 5}, + {"_TDA_FunFiltration", (DL_FUNC) &_TDA_FunFiltration, 2}, + {"_TDA_RipsFiltration", (DL_FUNC) &_TDA_RipsFiltration, 6}, + {"_TDA_RipsDiag", (DL_FUNC) &_TDA_RipsDiag, 8}, + {"_TDA_AlphaShapeFiltration", (DL_FUNC) &_TDA_AlphaShapeFiltration, 2}, + {"_TDA_AlphaShapeDiag", (DL_FUNC) &_TDA_AlphaShapeDiag, 5}, + {"_TDA_AlphaComplexFiltration", (DL_FUNC) &_TDA_AlphaComplexFiltration, 2}, + {"_TDA_AlphaComplexDiag", (DL_FUNC) &_TDA_AlphaComplexDiag, 5}, + {NULL, NULL, 0} +}; + +RcppExport void R_init_TDA(DllInfo *dll) { + R_registerRoutines(dll, NULL, CallEntries, NULL, NULL); + R_useDynamicSymbols(dll, FALSE); +} From 2a3a2f1a8627c9c3c35da0578b713ac804baedb9 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 18 Aug 2018 16:49:22 -0700 Subject: [PATCH 21/29] prep for rebasing master --- R/filtrationDiag.R | 9 ++++++--- src/diag.cpp | 4 ++-- src/rips.h | 2 +- src/tdautils/filtrationDiag.h | 8 +++++++- src/tdautils/typecastUtils.h | 25 +++++++++++++++++++++++-- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/R/filtrationDiag.R b/R/filtrationDiag.R index 5990fab..0945278 100644 --- a/R/filtrationDiag.R +++ b/R/filtrationDiag.R @@ -15,8 +15,11 @@ filtrationDiag <- function( if (library == "dionysus" || library == "DIONYSUS") { library <- "Dionysus" } - if (library != "GUDHI" && library != "Dionysus") { - stop("library for computing persistence diagram should be a string: either 'GUDHI' or 'Dionysus'") + if (library == "D2") { + library <- "D2" + } + if (library != "GUDHI" && library != "Dionysus" && library != "D2") { + stop("library for computing persistence diagram should be a string: either 'GUDHI' or 'Dionysus' or 'Dionysus2'") } if (!is.logical(location)) { stop("location should be logical") @@ -71,4 +74,4 @@ filtrationDiag <- function( "deathLocation" = DeathLocation, "cycleLocation" = CycleLocation) } return (out) -} \ No newline at end of file +} diff --git a/src/diag.cpp b/src/diag.cpp index ed3f894..103418d 100644 --- a/src/diag.cpp +++ b/src/diag.cpp @@ -20,7 +20,7 @@ // for Dionysus #include - +#include // for phat #include @@ -507,4 +507,4 @@ Rcpp::List AlphaComplexDiag( concatStlToRcpp< Rcpp::NumericMatrix >(persDgm, true, 3), concatStlToRcpp< Rcpp::NumericMatrix >(persLoc, false, 2), StlToRcppMatrixList< Rcpp::List, Rcpp::NumericMatrix >(persCycle)); -} \ No newline at end of file +} diff --git a/src/rips.h b/src/rips.h index 9d5c69c..32e811e 100644 --- a/src/rips.h +++ b/src/rips.h @@ -10,7 +10,7 @@ // for Dionysus #include -// for phat +// for phat #include // for Rips diff --git a/src/tdautils/filtrationDiag.h b/src/tdautils/filtrationDiag.h index a458a2a..fb649de 100644 --- a/src/tdautils/filtrationDiag.h +++ b/src/tdautils/filtrationDiag.h @@ -64,6 +64,12 @@ inline void filtrationDiagSorted( smplxTree, coeff_field_characteristic, min_persistence, maxdimension, printProgress, persDgm); } + else if (library[0] == 'D' && library[1] == '2') { + FiltrationDiagDionysus2( + filtrationTdaToDionysus2< VertexVector, Fltr2>( + cmplx, values, idxShift), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } else if (library[0] == 'D') { FiltrationDiagDionysus< Persistence >( filtrationTdaToDionysus< VertexVector, Fltr >( @@ -138,4 +144,4 @@ inline void filtrationDiag( -# endif // __FILTRATIONDIAG_H__ \ No newline at end of file +# endif // __FILTRATIONDIAG_H__ diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index 1bf08e2..82df01f 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -657,7 +657,28 @@ inline Filtration filtrationTdaToDionysus( return filtration; } - +//Marker D2 +template< typename IntegerVector, typename Filtration, typename VectorList, + typename RealVector > +inline Filtration filtrationTdaToDionysus2( + const VectorList & cmplx, const RealVector & values, + const unsigned idxShift) { + Filtration filtration; + typename VectorList::const_iterator iCmplx = cmplx.begin(); + typename RealVector::const_iterator iValue = values.begin(); + for (; iCmplx != cmplx.end(); ++iCmplx, ++iValue) { + const IntegerVector tdaVec(*iCmplx); + IntegerVector dionysusVec(tdaVec.size()); + typename IntegerVector::const_iterator iTda = tdaVec.begin(); + typename IntegerVector::iterator iDionysus = dionysusVec.begin(); + for (; iTda != tdaVec.end(); ++iTda, ++iDionysus) { + // R is 1-base, while C++ is 0-base + *iDionysus = *iTda - idxShift; + } + filtration.push_back(typename Filtration::Cell(dionysusVec, *iValue)); + } + return filtration; +} template< typename Filtration, typename RcppVector, typename RcppList > inline Filtration filtrationRcppToDionysus(const RcppList & rcppList) { @@ -753,4 +774,4 @@ inline void filtrationDionysusToPhat( -# endif // __TYPECASTUTILS_H__ \ No newline at end of file +# endif // __TYPECASTUTILS_H__ From 09a15e2a44005704250c7dcda976a7251749af3d Mon Sep 17 00:00:00 2001 From: thomashli Date: Thu, 13 Dec 2018 21:20:32 -0800 Subject: [PATCH 22/29] updating --- src/rips.h | 12 +++++++++- src/tdautils/diagramDS.h | 1 + src/tdautils/dionysus2Utils.h | 43 +++++++++++++++++++++++++++++++++++ src/tdautils/typecastUtils.h | 33 +++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/rips.h b/src/rips.h index 32e811e..f62d60f 100644 --- a/src/rips.h +++ b/src/rips.h @@ -14,6 +14,7 @@ #include // for Rips +#include #include #include @@ -59,10 +60,19 @@ inline void ripsFiltration( maxdimension, maxscale, printProgress, print); filtrationGudhiToTda< IntVector >(smplxTree, cmplx, values, boundary); } + else { if (dist[0] == 'e') { - // RipsDiag for L2 distance + // RipsDiag for L2 distance + /* + if (library[0] == 'D' && library([1] == '2') { + filtrationDionysus2ToTda< IntVector >( + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } + */ filtrationDionysusToTda< IntVector >( RipsFiltrationDionysus< PairDistances, Generator, FltrR >(X, nSample, nDim, false, maxdimension, maxscale, printProgress, print), diff --git a/src/tdautils/diagramDS.h b/src/tdautils/diagramDS.h index 84333b1..8cbca88 100644 --- a/src/tdautils/diagramDS.h +++ b/src/tdautils/diagramDS.h @@ -4,6 +4,7 @@ #include #include +#include namespace d = dionysus; diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h index 674c4ad..f379f45 100644 --- a/src/tdautils/dionysus2Utils.h +++ b/src/tdautils/dionysus2Utils.h @@ -85,4 +85,47 @@ void FiltrationDiagDionysus2( } } +template< typename Distances, typename Generator, typename Filtration, + typename RealMatrix, typename Print > +inline Filtration RipsFiltrationDionysus2( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const bool is_row_names, + const int maxdimension, + const double maxscale, + const bool printProgress, + const Print & print +) { + + PointContainer points = TdaToStl< PointContainer >(X, nSample, nDim, + is_row_names); + //lol copy paste + //read_points(infilename, points); + //read_points2(infilename, points); + + Distances distances(points); //PairDistances2 + Generator rips(distances); //Generator2 + typename Generator::Evaluator size(distances); + Filtration filtration; + //EvaluatePushBack< Filtration, typename Generator::Evaluator > functor(filtration, size); + auto functor = [&filtration](Simplex2&& s) { filtration.push_back(s); }; + // Generate maxdimension skeleton of the Rips complex + // rips.generate(skeleton, max_distance, [&filtration](Simplex&& s) { filtration.push_back(s); }); + + rips.generate(maxdimension + 1, maxscale, functor); + + if (printProgress) { + print("# Generated complex of size: %d \n", filtration.size()); + } + + // Sort the simplices with respect to comparison criteria + // e.g. distance or function values + // filtration.sort(ComparisonDataDimension< typename Filtration::Simplex >()); + filtration.sort(Generator::Comparison(distances)); + + return filtration; +} + + #endif __DIONYSUS2UTILS_H__ diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index 82df01f..94a1ea6 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -595,6 +595,39 @@ inline void filtrationDionysusToTda( } +template< typename IntegerVector, typename Filtration, typename VectorList, + typename RealVector > +inline void filtrationDionysus2ToTda( + const Filtration & filtration, VectorList & cmplx, RealVector & values, + VectorList & boundary) { + + const unsigned nFltr = filtration.size(); + std::map< typename Filtration::Cell, unsigned, + typename Filtration::Cell::VertexComparison > simplex_map; + unsigned size_of_simplex_map = 0; + + cmplx = VectorList(nFltr); + values = RealVector(nFltr); + boundary = VectorList(nFltr); + typename VectorList::iterator iCmplx = cmplx.begin(); + typename RealVector::iterator iValue = values.begin(); + typename VectorList::iterator iBdy = boundary.begin(); + + for (typename Filtration::Index it = filtration.begin(); + it != filtration.end(); ++it, ++iCmplx, ++iValue, ++iBdy) { + const typename Filtration::Simplex & c = filtration.simplex(it); + + IntegerVector cmplxVec; + IntegerVector boundaryVec; + filtrationDionysusOne(c, simplex_map, 1, cmplxVec, *iValue, boundaryVec); + *iCmplx = cmplxVec; + *iBdy = boundaryVec; + + simplex_map.insert(typename + std::map< typename Filtration::Simplex, unsigned >::value_type( + c, size_of_simplex_map++)); + } +} template< typename RcppList, typename RcppVector, typename Filtration > inline RcppList filtrationDionysusToRcpp(const Filtration & filtration) { From cb5936ec839277bcd4f2da88b6c0a966a2d0acad Mon Sep 17 00:00:00 2001 From: thomashli Date: Tue, 25 Dec 2018 18:08:07 -0500 Subject: [PATCH 23/29] ripsL2 Dionysus2 working --- src/diag.cpp | 1 + src/rips.h | 26 +++++++++++++------ src/tdautils/dionysus2Utils.h | 33 +++++++++++++++++++++--- src/tdautils/filtrationDiag.h | 1 + src/tdautils/gridUtils.h | 2 -- src/tdautils/ripsD2L2.h | 34 +++++++++++++++++++++++++ src/tdautils/ripsL2.h | 2 +- src/tdautils/typecastUtils.h | 48 ++++++++++++++++++++++++++++++----- tests/testthat/test_rips.R | 4 +-- 9 files changed, 128 insertions(+), 23 deletions(-) create mode 100644 src/tdautils/ripsD2L2.h diff --git a/src/diag.cpp b/src/diag.cpp index 103418d..a4e95b6 100644 --- a/src/diag.cpp +++ b/src/diag.cpp @@ -7,6 +7,7 @@ // for Rips #include +#include #include // for grid diff --git a/src/rips.h b/src/rips.h index f62d60f..7a28065 100644 --- a/src/rips.h +++ b/src/rips.h @@ -65,18 +65,19 @@ inline void ripsFiltration( if (dist[0] == 'e') { // RipsDiag for L2 distance - /* - if (library[0] == 'D' && library([1] == '2') { + + if (library[0] == 'D' && library[1] == '2') { filtrationDionysus2ToTda< IntVector >( RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, nDim, false, maxdimension, maxscale, printProgress, print), cmplx, values, boundary); } - */ - filtrationDionysusToTda< IntVector >( - RipsFiltrationDionysus< PairDistances, Generator, FltrR >(X, nSample, - nDim, false, maxdimension, maxscale, printProgress, print), - cmplx, values, boundary); + else{ + filtrationDionysusToTda< IntVector >( + RipsFiltrationDionysus< PairDistances, Generator, FltrR >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } } else { // RipsDiag for arbitrary distance @@ -162,11 +163,19 @@ inline void ripsDiag( else { if (dist[0] == 'e') { // RipsDiag for L2 distance + if (libraryDiag[0] == 'D' && libraryDiag[0] == '2') { + FiltrationDiagDionysus2( + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + maxdimension, location, printProgress, persDgm, persLoc, persCycle + ); + } + else { FltrR filtration = RipsFiltrationDionysus< PairDistances, Generator, FltrR >( X, nSample, nDim, false, maxdimension, maxscale, printProgress, print); - + if (libraryDiag[0] == 'D') { FiltrationDiagDionysus< Persistence >( filtration, maxdimension, location, printProgress, persDgm, @@ -191,6 +200,7 @@ inline void ripsDiag( cmplx, values, boundary_matrix, maxdimension, location, printProgress, persDgm, persLoc, persCycle); } + } } else { // RipsDiag for arbitrary distance diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h index f379f45..417f07d 100644 --- a/src/tdautils/dionysus2Utils.h +++ b/src/tdautils/dionysus2Utils.h @@ -13,6 +13,9 @@ namespace d = dionysus; + + + // FiltrationDiag in Dionysus2 /** \brief Construct the persistence diagram from the filtration using library * Dionysus. @@ -85,6 +88,28 @@ void FiltrationDiagDionysus2( } } +/** + * Class: EvaluatePushBack + * + * Push back the simplex and the evaluated value + */ +template< typename Container, typename Evaluator > +class EvaluatePushBack2 { + +public: + EvaluatePushBack2(Container & argContainer, const Evaluator & argEvaluator) : + container(argContainer), evaluator(argEvaluator) {} + + void operator()(const typename Container::value_type & argSmp) const { + typename Container::value_type smp(argSmp.dimension(),argSmp.begin(),argSmp.end(), evaluator(argSmp)); + container.push_back(smp); + } + +private: + Container & container; + const Evaluator & evaluator; +}; + template< typename Distances, typename Generator, typename Filtration, typename RealMatrix, typename Print > inline Filtration RipsFiltrationDionysus2( @@ -108,11 +133,11 @@ inline Filtration RipsFiltrationDionysus2( Generator rips(distances); //Generator2 typename Generator::Evaluator size(distances); Filtration filtration; - //EvaluatePushBack< Filtration, typename Generator::Evaluator > functor(filtration, size); - auto functor = [&filtration](Simplex2&& s) { filtration.push_back(s); }; + EvaluatePushBack2< Filtration, typename Generator::Evaluator > functor(filtration, size); + //auto functor = [&filtration](typename Generator::Simplex&& s) { filtration.push_back(s); }; // Generate maxdimension skeleton of the Rips complex // rips.generate(skeleton, max_distance, [&filtration](Simplex&& s) { filtration.push_back(s); }); - + // rips.generate(maxdimension + 1, maxscale, functor); if (printProgress) { @@ -122,7 +147,7 @@ inline Filtration RipsFiltrationDionysus2( // Sort the simplices with respect to comparison criteria // e.g. distance or function values // filtration.sort(ComparisonDataDimension< typename Filtration::Simplex >()); - filtration.sort(Generator::Comparison(distances)); + filtration.sort(typename Generator::Comparison(distances)); return filtration; } diff --git a/src/tdautils/filtrationDiag.h b/src/tdautils/filtrationDiag.h index fb649de..eac5529 100644 --- a/src/tdautils/filtrationDiag.h +++ b/src/tdautils/filtrationDiag.h @@ -15,6 +15,7 @@ // for Dionysus #include +#include // for phat #include diff --git a/src/tdautils/gridUtils.h b/src/tdautils/gridUtils.h index 65bb5b7..3a89f02 100644 --- a/src/tdautils/gridUtils.h +++ b/src/tdautils/gridUtils.h @@ -48,8 +48,6 @@ typedef OffsetBeginMap Smplx2; typedef d::Filtration Fltr2; -typedef d::Simplex<> Simplex2; -typedef d::Filtration Filtration2; typedef d::ReducedMatrix Persistence2; typedef d::StandardReduction StandardReduction2; diff --git a/src/tdautils/ripsD2L2.h b/src/tdautils/ripsD2L2.h new file mode 100644 index 0000000..be30224 --- /dev/null +++ b/src/tdautils/ripsD2L2.h @@ -0,0 +1,34 @@ +//dionysus2 +#include +#include +#include +#include +#include +#include +#include +#include + +#include // for BackInsertFunctor +#include + + +#include + +namespace d = dionysus; + + +//L2 Struct is inside distances.h + +typedef std::vector Point; +typedef std::vector PointContainer; + +typedef d::PairwiseDistances> PairDistances2; +typedef PairDistances2::DistanceType DistanceType2; +typedef PairDistances2::IndexType VertexR2; + +typedef d::Rips< PairDistances2, d::Simplex< VertexR, double > > Generator2; +typedef Generator2::Simplex SmplxR2; +typedef d::Filtration FltrR2; + +typedef d::Z2Field K; +typedef d::PairRecorder> RipsPersistence2; diff --git a/src/tdautils/ripsL2.h b/src/tdautils/ripsL2.h index 81e2d44..3c52496 100644 --- a/src/tdautils/ripsL2.h +++ b/src/tdautils/ripsL2.h @@ -1,5 +1,5 @@ #include -//#include +#include #include #include #include diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index 94a1ea6..fe90140 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -529,6 +529,36 @@ inline void filtrationGudhiToPhat( +template< typename Simplex, typename SimplexMap, typename RealVector > +inline void filtrationDionysus2( + const Simplex & c, const SimplexMap & simplex_map, const int idxShift, + RealVector & cmplxVec, double & value, RealVector & boundaryVec) { + + const unsigned nVtx = c.dimension() + 1; + + cmplxVec = RealVector(nVtx); + typename RealVector::iterator iCmplxVec = cmplxVec.begin(); + //Simplex::Vertices?::const_iterator, want array from unique_ptr + for (auto vit = c.begin(); vit != c.end(); ++vit, ++iCmplxVec) { + // R is 1-base, while C++ is 0-base + *iCmplxVec = *vit + idxShift; + } + + value = c.data(); + + // might need to change for cubical complex + if (nVtx > 1) { + boundaryVec = RealVector(nVtx); + } + typename RealVector::iterator iBdyVec = boundaryVec.begin(); + for (typename Simplex::BoundaryIterator bit = c.boundary_begin(); + bit != c.boundary_end(); ++bit, ++iBdyVec) { + // R is 1-base, while C++ is 0-base + *iBdyVec = simplex_map.find(*bit)->second + idxShift; + } +} + + template< typename Simplex, typename SimplexMap, typename RealVector > inline void filtrationDionysusOne( const Simplex & c, const SimplexMap & simplex_map, const int idxShift, @@ -602,8 +632,14 @@ inline void filtrationDionysus2ToTda( VectorList & boundary) { const unsigned nFltr = filtration.size(); - std::map< typename Filtration::Cell, unsigned, - typename Filtration::Cell::VertexComparison > simplex_map; + //auto lambda = [](typename Filtration::Cell& a,typename Filtration::Cell& b){return a < b;}; + //template + struct VertexComparison2 + { + bool operator()(const typename Filtration::Cell& a, const typename Filtration::Cell& b) const + { return a < b; } + }; + std::map< typename Filtration::Cell, unsigned, VertexComparison2> simplex_map; unsigned size_of_simplex_map = 0; cmplx = VectorList(nFltr); @@ -613,18 +649,18 @@ inline void filtrationDionysus2ToTda( typename RealVector::iterator iValue = values.begin(); typename VectorList::iterator iBdy = boundary.begin(); - for (typename Filtration::Index it = filtration.begin(); + for (typename Filtration::OrderConstIterator it = filtration.begin(); it != filtration.end(); ++it, ++iCmplx, ++iValue, ++iBdy) { - const typename Filtration::Simplex & c = filtration.simplex(it); + const typename Filtration::Cell & c = *it; IntegerVector cmplxVec; IntegerVector boundaryVec; - filtrationDionysusOne(c, simplex_map, 1, cmplxVec, *iValue, boundaryVec); + filtrationDionysus2(c, simplex_map, 1, cmplxVec, *iValue, boundaryVec); *iCmplx = cmplxVec; *iBdy = boundaryVec; simplex_map.insert(typename - std::map< typename Filtration::Simplex, unsigned >::value_type( + std::map< typename Filtration::Cell, unsigned >::value_type( c, size_of_simplex_map++)); } } diff --git a/tests/testthat/test_rips.R b/tests/testthat/test_rips.R index a479f6f..ed65afa 100644 --- a/tests/testthat/test_rips.R +++ b/tests/testthat/test_rips.R @@ -6,7 +6,7 @@ test_that("default circle example ripsFiltration", { maxdimension <- 1 maxscale <- 1.5 FltRips <- ripsFiltration(X = X, maxdimension = maxdimension, - maxscale = maxscale, dist = "euclidean", library = "Dionysus", + maxscale = maxscale, dist = "euclidean", library = "D2", printProgress = TRUE) expect_equal(FltRips$cmplx[[1]],1) expect_true(FltRips$increasing) @@ -17,7 +17,7 @@ test_that("default circle example ripsFiltration", { test_that("One dimensional ripsFiltration in a line", { Y <- matrix(c(1,2.1,3.3,4.6,6)) FltRips <- ripsFiltration(X =Y, maxdimension = 1, - maxscale = 1.5, dist = "euclidean", library = "Dionysus", + maxscale = 1.5, dist = "euclidean", library = "D2", printProgress = TRUE) expect_equal(FltRips$cmplx[[1]],1) expect_equal(FltRips$cmplx[[6]],c(1,2)) From 9b50bd2cbb186218f59f2378430075aa1a7b0d41 Mon Sep 17 00:00:00 2001 From: thomashli Date: Tue, 25 Dec 2018 20:50:33 -0500 Subject: [PATCH 24/29] Arbit Working --- src/diag.cpp | 1 + src/rips.h | 19 ++++++++++++++++++- src/tdautils/ripsD2Arbit.h | 29 +++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 src/tdautils/ripsD2Arbit.h diff --git a/src/diag.cpp b/src/diag.cpp index a4e95b6..37c598f 100644 --- a/src/diag.cpp +++ b/src/diag.cpp @@ -9,6 +9,7 @@ #include #include #include +#include // for grid #include diff --git a/src/rips.h b/src/rips.h index 7a28065..d97d0b0 100644 --- a/src/rips.h +++ b/src/rips.h @@ -17,7 +17,7 @@ #include #include #include - +#include // ripsFiltration @@ -80,12 +80,20 @@ inline void ripsFiltration( } } else { + + if (library[0] == 'D' && library[1] == '2') { + filtrationDionysus2ToTda< IntVector >( + RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, + nDim, true, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } else { // RipsDiag for arbitrary distance filtrationDionysusToTda< IntVector >( RipsFiltrationDionysus< PairDistancesA, GeneratorA, FltrRA >(X, nSample, nDim, true, maxdimension, maxscale, printProgress, print), cmplx, values, boundary); + } } } } @@ -204,6 +212,14 @@ inline void ripsDiag( } else { // RipsDiag for arbitrary distance + + if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { + FiltrationDiagDionysus2( + RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, + nDim, true, maxdimension, maxscale, printProgress, print), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } else { + FltrRA filtration = RipsFiltrationDionysus< PairDistancesA, GeneratorA, FltrRA >( X, nSample, nDim, true, maxdimension, maxscale, @@ -234,6 +250,7 @@ inline void ripsDiag( printProgress, persDgm, persLoc, persCycle); } } + } } } diff --git a/src/tdautils/ripsD2Arbit.h b/src/tdautils/ripsD2Arbit.h new file mode 100644 index 0000000..24ea4da --- /dev/null +++ b/src/tdautils/ripsD2Arbit.h @@ -0,0 +1,29 @@ +//dionysus2 +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // for BackInsertFunctor +#include + +#include + +namespace d = dionysus; + +typedef d::PairwiseDistances>, ArbitDistance> PairDistances2A; +typedef PairDistances2A::DistanceType DistanceType2A; +typedef PairDistances2A::IndexType VertexR2A; +typedef d::Rips< PairDistances2A, d::Simplex< VertexR2A, double > > Generator2A; +typedef Generator2A::Simplex SmplxR2A; +typedef d::Filtration FltrR2A; +//typedef StaticPersistence<> PersistenceR; +//typedef DynamicPersistenceChains<> PersistenceR; +//typedef PersistenceDiagram<> PDgmR; +typedef d::Z2Field K; +typedef d::PairRecorder> RipsPersistence2A; From 6956426f2e2b0bb0c36a719036ba93af827f68a1 Mon Sep 17 00:00:00 2001 From: thomashli Date: Tue, 25 Dec 2018 21:07:58 -0500 Subject: [PATCH 25/29] added R functionality for rips D2 --- R/ripsFiltration.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/R/ripsFiltration.R b/R/ripsFiltration.R index 5937e9f..77f386a 100644 --- a/R/ripsFiltration.R +++ b/R/ripsFiltration.R @@ -21,7 +21,10 @@ ripsFiltration <- function( if (library == "dionysus" || library == "DIONYSUS") { library <- "Dionysus" } - if (library != "GUDHI" && library != "Dionysus") { + if (library == "D2") { + library <- "D2" + } + if (library != "GUDHI" && library != "Dionysus" && library != "D2") { stop("library should be a string: either 'GUDHI' or 'Dionysus'") } if (!is.logical(printProgress)) { From 4f73ef0aa83805b39b15a3f1ceb767f0cb163799 Mon Sep 17 00:00:00 2001 From: thomashli Date: Tue, 25 Dec 2018 21:14:42 -0500 Subject: [PATCH 26/29] dionysusUtils reloaded --- src/tdautils/dionysusUtils.h | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/tdautils/dionysusUtils.h b/src/tdautils/dionysusUtils.h index 925b85e..2ae8d58 100644 --- a/src/tdautils/dionysusUtils.h +++ b/src/tdautils/dionysusUtils.h @@ -413,21 +413,17 @@ inline Filtration RipsFiltrationDionysus2( Filtration filtration; EvaluatePushBack< Filtration, typename Generator::Evaluator > functor( filtration, size); - // Generate maxdimension skeleton of the Rips complex rips.generate(maxdimension + 1, maxscale, functor); - if (printProgress) { print("# Generated complex of size: %d \n", filtration.size()); } - // Sort the simplices with respect to comparison criteria // e.g. distance or function values filtration.sort(ComparisonDataDimension< typename Filtration::Simplex >()); - return filtration; } */ -# endif // __DIONYSUSUTILS_H__ +# endif // __DIONYSUSUTILS_H__ \ No newline at end of file From 5a31e05c569f62725420666f9490ee86c8c66090 Mon Sep 17 00:00:00 2001 From: thomashli Date: Tue, 8 Jan 2019 13:26:01 -0500 Subject: [PATCH 27/29] finished dionysus2 methods --- src/alphaComplex.h | 56 +++++++++------- src/alphaShape.h | 6 ++ src/rips.h | 10 ++- src/tdautils/ripsL2backup.h | 4 +- src/tdautils/typecastUtils.h | 123 ++++++++++++++++++++++++++++++++++- 5 files changed, 168 insertions(+), 31 deletions(-) diff --git a/src/alphaComplex.h b/src/alphaComplex.h index 928ad95..395180d 100644 --- a/src/alphaComplex.h +++ b/src/alphaComplex.h @@ -1,14 +1,14 @@ -#include -#include - -// for changing formats and typecasting -#include - -//for GUDHI -#include - -// for Dionysus -#include +#include +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include // for phat #include @@ -29,15 +29,15 @@ template< typename RealMatrix, typename Print > void alphaComplexDiag( const RealMatrix & X, //points to some memory space - const unsigned nSample, - const unsigned nDim, + const unsigned nSample, + const unsigned nDim, const int maxdimension, const std::string & libraryDiag, const bool location, const bool printProgress, const Print & print, - std::vector< std::vector< std::vector< double > > > & persDgm, - std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle ) { @@ -62,25 +62,31 @@ void alphaComplexDiag( FiltrationDiagGudhi( alphaCmplx, coeff_field_characteristic, min_persistence, 2, printProgress, persDgm); + } + else if (libraryDiag[0] == 'D' && libraryDiag[2] == '2') { + Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(alphaCmplx); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); } else if (libraryDiag[0] == 'D') { // 2018-08-04 // switching back to original code - Fltr filtration = filtrationGudhiToDionysus< Fltr >(alphaCmplx); - FiltrationDiagDionysus< Persistence >( - filtration, maxdimension, location, printProgress, persDgm, persLoc, + Fltr filtration = filtrationGudhiToDionysus< Fltr >(alphaCmplx); + FiltrationDiagDionysus< Persistence >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, persCycle); } else { // 2018-08-04 // switching back to original code - std::vector< phat::column > cmplx; - std::vector< double > values; - phat::boundary_matrix< phat::vector_vector > boundary_matrix; - filtrationGudhiToPhat< phat::column, phat::dimension >( - alphaCmplx, cmplx, values, boundary_matrix); - FiltrationDiagPhat( - cmplx, values, boundary_matrix, maxdimension, location, + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationGudhiToPhat< phat::column, phat::dimension >( + alphaCmplx, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, printProgress, persDgm, persLoc, persCycle); } } diff --git a/src/alphaShape.h b/src/alphaShape.h index a40577b..85bc08d 100644 --- a/src/alphaShape.h +++ b/src/alphaShape.h @@ -55,6 +55,12 @@ void alphaShapeDiag( smplxTree, coeff_field_characteristic, min_persistence, 2, printProgress, persDgm); } + else if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { + Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(smplxTree); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } else if (libraryDiag[0] == 'D') { Fltr filtration = filtrationGudhiToDionysus< Fltr >(smplxTree); FiltrationDiagDionysus< Persistence >( diff --git a/src/rips.h b/src/rips.h index 02206cf..ab11cd9 100644 --- a/src/rips.h +++ b/src/rips.h @@ -66,7 +66,7 @@ inline void ripsFiltration( if (dist[0] == 'e') { // RipsDiag for L2 distance if (library[0] == 'D' && library[1] == '2') { - filtrationDionysus2ToTda< IntVector >( + filtrationDionysus2Tda< IntVector >( RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, nDim, false, maxdimension, maxscale, printProgress, print), cmplx, values, boundary); @@ -81,7 +81,7 @@ inline void ripsFiltration( else { if (library[0] == 'D' && library[1] == '2') { - filtrationDionysus2ToTda< IntVector >( + filtrationDionysus2Tda< IntVector >( RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, nDim, true, maxdimension, maxscale, printProgress, print), cmplx, values, boundary); @@ -150,6 +150,12 @@ inline void ripsDiag( FiltrationDiagGudhi( smplxTree, p, min_persistence, maxdimension, printProgress, persDgm); } + else if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { + FltrR2 filtration = filtrationGudhiToDionysus2< FltrR2 >(smplxTree); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } else if (libraryDiag[0] == 'D') { FltrR filtration = filtrationGudhiToDionysus< FltrR >(smplxTree); FiltrationDiagDionysus< Persistence >( diff --git a/src/tdautils/ripsL2backup.h b/src/tdautils/ripsL2backup.h index 172a900..5dfb996 100644 --- a/src/tdautils/ripsL2backup.h +++ b/src/tdautils/ripsL2backup.h @@ -1,5 +1,5 @@ #include -//#include +#include #include #include #include @@ -11,7 +11,7 @@ //dionysus2 //#include -#include +//#include //#include //#include //#include diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index fe90140..1980f57 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -485,6 +485,34 @@ inline Filtration filtrationGudhiToDionysus(SimplexTree & smplxTree) { } +// TODO : see whether 'const SimplexTree &' is possible +template< typename Filtration, typename SimplexTree > +inline Filtration filtrationGudhiToDionysus2(SimplexTree & smplxTree) { + + const typename SimplexTree::Filtration_simplex_range & fltrGudhi = + smplxTree.filtration_simplex_range(); + Filtration fltrDionysus2; + unsigned iFill = 0; + + for (typename SimplexTree::Filtration_simplex_iterator iSt = + fltrGudhi.begin(); iSt != fltrGudhi.end(); ++iSt) { + + // Below two lines are only needed for computing boundary + smplxTree.assign_key(*iSt, iFill); + iFill++; + + std::vector< double > cmplxVec; + double value; + std::vector< double > boundaryVec; + filtrationGudhiOne(*iSt, smplxTree, 0, cmplxVec, value, boundaryVec); + + fltrDionysus2.push_back(typename Filtration::Cell(cmplxVec.size(), + cmplxVec.begin(), cmplxVec.end(), value)); + } + + return fltrDionysus2; +} + // TODO : see whether 'const SimplexTree &' is possible template< typename Column, typename Dimension, typename SimplexTree, @@ -627,7 +655,7 @@ inline void filtrationDionysusToTda( template< typename IntegerVector, typename Filtration, typename VectorList, typename RealVector > -inline void filtrationDionysus2ToTda( +inline void filtrationDionysus2Tda( const Filtration & filtration, VectorList & cmplx, RealVector & values, VectorList & boundary) { @@ -775,7 +803,6 @@ inline Filtration filtrationRcppToDionysus(const RcppList & rcppList) { } - template< typename SimplexTree, typename Filtration > inline SimplexTree filtrationDionysusToGudhi(const Filtration & filtration) { @@ -803,6 +830,55 @@ inline SimplexTree filtrationDionysusToGudhi(const Filtration & filtration) { } +template< typename SimplexTree, typename Filtration > +inline SimplexTree filtrationDionysus2Gudhi(const Filtration & filtration) { + // use custom VertexComparison with Dionysus2 + struct VertexComparison2 + { + bool operator()(const typename Filtration::Cell& a, const typename Filtration::Cell& b) const + { return a < b; } + }; + std::map< typename Filtration::Cell, unsigned, + VertexComparison2 > simplex_map; + unsigned size_of_simplex_map = 0; + SimplexTree smplxTree; + + for (typename Filtration::OrderConstIterator it = filtration.begin(); + it != filtration.end(); ++it) { + const typename Filtration::Cell & c = *it; + + std::vector< double > cmplxVec; + double value; + std::vector< double > boundaryVec; + + filtrationDionysus2(c, simplex_map, 0, cmplxVec, value, boundaryVec); + + smplxTree.insert_simplex(cmplxVec, value); + + simplex_map.insert(typename std::map< typename Filtration::Cell, + unsigned >::value_type(c, size_of_simplex_map++)); + } +/* template + * + for (typename Filtration::OrderConstIterator it = filtration.begin(); + it != filtration.end(); ++it, ++iCmplx, ++iValue, ++iBdy) { + const typename Filtration::Cell & c = *it; + + IntegerVector cmplxVec; + IntegerVector boundaryVec; + filtrationDionysus2(c, simplex_map, 1, cmplxVec, *iValue, boundaryVec); + *iCmplx = cmplxVec; + *iBdy = boundaryVec; + + simplex_map.insert(typename + std::map< typename Filtration::Cell, unsigned >::value_type( + c, size_of_simplex_map++)); + } +*/ + return smplxTree; +} + + template< typename Column, typename Dimension, typename Filtration, typename VectorList, typename RealVector, typename Boundary > @@ -842,5 +918,48 @@ inline void filtrationDionysusToPhat( } +template< typename Column, typename Dimension, typename Filtration, + typename VectorList, typename RealVector, typename Boundary > +inline void filtrationDionysus2Phat( + const Filtration & filtration, VectorList & cmplx, RealVector & values, + Boundary & boundary_matrix) { + // use custom VertexComparison with Dionysus2 + struct VertexComparison2 + { + bool operator()(const typename Filtration::Cell& a, const typename Filtration::Cell& b) const + { return a < b; } + }; + const unsigned nFltr = filtration.size(); + std::map< typename Filtration::Cell, typename Column::value_type, + VertexComparison2 > simplex_map; + typename Column::value_type size_of_simplex_map = 0; + + cmplx = VectorList(nFltr); + values = RealVector(nFltr); + boundary_matrix.set_num_cols(nFltr); + typename VectorList::iterator iCmplx = cmplx.begin(); + typename RealVector::iterator iValue = values.begin(); + + for (typename Filtration::OrderConstIterator it = filtration.begin(); + it != filtration.end(); ++it, ++iCmplx, ++iValue) { + const typename Filtration::Cell & c = *it; + + Column cmplxVec; + Column boundary_indices; + filtrationDionysus2( + c, simplex_map, 0, cmplxVec, *iValue, boundary_indices); + *iCmplx = cmplxVec; + + std::sort(boundary_indices.begin(), boundary_indices.end()); + boundary_matrix.set_col(size_of_simplex_map, boundary_indices); + Dimension dim_of_column = c.dimension(); + boundary_matrix.set_dim(size_of_simplex_map, dim_of_column); + + simplex_map.insert(typename std::map< typename Filtration::Cell, + typename Column::value_type >::value_type(c, size_of_simplex_map++)); + } +} + + # endif // __TYPECASTUTILS_H__ From 9684929bc0711ccee5cd049f2bda2add0868278e Mon Sep 17 00:00:00 2001 From: thomashli Date: Wed, 9 Jan 2019 13:40:02 -0500 Subject: [PATCH 28/29] figure out persistence bug --- src/.DS_Store | Bin 0 -> 6148 bytes src/.diag.cpp.swp | Bin 0 -> 36864 bytes src/.grid.h.swp | Bin 0 -> 16384 bytes src/alphaComplex.h | 13 +- src/alphaComplexb.h | 93 + src/alphaShape.h | 154 +- src/alphaShapeb.h | 81 + src/diag.cpp | 21 +- src/dionysus/bottleneck/basic_defs_bt.h | 476 + src/dionysus/bottleneck/bottleneck.h | 118 + src/dionysus/bottleneck/bottleneck_detail.h | 85 + src/dionysus/bottleneck/bottleneck_detail.hpp | 507 + src/dionysus/bottleneck/bound_match.h | 107 + src/dionysus/bottleneck/bound_match.hpp | 473 + src/dionysus/bottleneck/def_debug_bt.h | 42 + src/dionysus/bottleneck/diagram_reader.h | 196 + src/dionysus/bottleneck/diagram_traits.h | 45 + .../bottleneck/dnn/geometry/euclidean-fixed.h | 162 + src/dionysus/bottleneck/dnn/local/kd-tree.h | 106 + src/dionysus/bottleneck/dnn/local/kd-tree.hpp | 296 + .../bottleneck/dnn/local/search-functors.h | 119 + src/dionysus/bottleneck/dnn/parallel/tbb.h | 235 + src/dionysus/bottleneck/dnn/parallel/utils.h | 100 + src/dionysus/bottleneck/dnn/utils.h | 47 + src/dionysus/bottleneck/neighb_oracle.h | 295 + src/dionysus/wasserstein/auction_oracle.h | 40 + .../wasserstein/auction_oracle_base.h | 85 + .../wasserstein/auction_oracle_base.hpp | 97 + .../auction_oracle_kdtree_pure_geom.h | 97 + .../auction_oracle_kdtree_pure_geom.hpp | 244 + .../auction_oracle_kdtree_restricted.h | 122 + .../auction_oracle_kdtree_restricted.hpp | 598 + .../auction_oracle_kdtree_single_diag.h | 219 + .../auction_oracle_kdtree_single_diag.hpp | 717 + .../wasserstein/auction_oracle_lazy_heap.h | 191 + .../wasserstein/auction_oracle_lazy_heap.hpp | 465 + .../auction_oracle_stupid_sparse_restricted.h | 114 + ...uction_oracle_stupid_sparse_restricted.hpp | 568 + src/dionysus/wasserstein/auction_runner_fr.h | 289 + .../wasserstein/auction_runner_fr.hpp | 1440 ++ src/dionysus/wasserstein/auction_runner_gs.h | 122 + .../wasserstein/auction_runner_gs.hpp | 486 + .../auction_runner_gs_single_diag.h | 149 + .../auction_runner_gs_single_diag.hpp | 738 + src/dionysus/wasserstein/auction_runner_jac.h | 230 + .../wasserstein/auction_runner_jac.hpp | 873 ++ src/dionysus/wasserstein/basic_defs_ws.h | 336 + src/dionysus/wasserstein/basic_defs_ws.hpp | 193 + src/dionysus/wasserstein/catch/catch.hpp | 11545 ++++++++++++++++ src/dionysus/wasserstein/def_debug_ws.h | 44 + src/dionysus/wasserstein/diagonal_heap.h | 149 + src/dionysus/wasserstein/diagram_reader.h | 369 + .../dnn/geometry/euclidean-dynamic.h | 248 + .../dnn/geometry/euclidean-fixed.h | 196 + src/dionysus/wasserstein/dnn/local/kd-tree.h | 97 + .../wasserstein/dnn/local/kd-tree.hpp | 330 + .../wasserstein/dnn/local/search-functors.h | 95 + src/dionysus/wasserstein/dnn/parallel/tbb.h | 237 + src/dionysus/wasserstein/dnn/parallel/utils.h | 100 + src/dionysus/wasserstein/dnn/utils.h | 47 + .../wasserstein/spdlog/async_logger.h | 82 + src/dionysus/wasserstein/spdlog/common.h | 160 + .../spdlog/details/async_log_helper.h | 399 + .../spdlog/details/async_logger_impl.h | 105 + .../wasserstein/spdlog/details/file_helper.h | 117 + .../wasserstein/spdlog/details/log_msg.h | 50 + .../wasserstein/spdlog/details/logger_impl.h | 563 + .../spdlog/details/mpmc_bounded_q.h | 172 + .../wasserstein/spdlog/details/null_mutex.h | 45 + src/dionysus/wasserstein/spdlog/details/os.h | 469 + .../spdlog/details/pattern_formatter_impl.h | 690 + .../wasserstein/spdlog/details/registry.h | 214 + .../wasserstein/spdlog/details/spdlog_impl.h | 263 + .../wasserstein/spdlog/fmt/bundled/format.cc | 940 ++ .../wasserstein/spdlog/fmt/bundled/format.h | 4501 ++++++ .../wasserstein/spdlog/fmt/bundled/ostream.cc | 43 + .../wasserstein/spdlog/fmt/bundled/ostream.h | 126 + .../wasserstein/spdlog/fmt/bundled/posix.cc | 238 + .../wasserstein/spdlog/fmt/bundled/posix.h | 443 + .../wasserstein/spdlog/fmt/bundled/time.h | 58 + src/dionysus/wasserstein/spdlog/fmt/fmt.h | 28 + src/dionysus/wasserstein/spdlog/fmt/ostr.h | 17 + src/dionysus/wasserstein/spdlog/formatter.h | 47 + src/dionysus/wasserstein/spdlog/logger.h | 132 + .../wasserstein/spdlog/sinks/android_sink.h | 90 + .../wasserstein/spdlog/sinks/ansicolor_sink.h | 133 + .../wasserstein/spdlog/sinks/base_sink.h | 50 + .../wasserstein/spdlog/sinks/dist_sink.h | 73 + .../wasserstein/spdlog/sinks/file_sinks.h | 242 + .../wasserstein/spdlog/sinks/msvc_sink.h | 51 + .../wasserstein/spdlog/sinks/null_sink.h | 34 + .../wasserstein/spdlog/sinks/ostream_sink.h | 47 + src/dionysus/wasserstein/spdlog/sinks/sink.h | 53 + .../wasserstein/spdlog/sinks/stdout_sinks.h | 77 + .../wasserstein/spdlog/sinks/syslog_sink.h | 81 + .../wasserstein/spdlog/sinks/wincolor_sink.h | 117 + src/dionysus/wasserstein/spdlog/spdlog.h | 187 + src/dionysus/wasserstein/spdlog/tweakme.h | 141 + src/dionysus/wasserstein/wasserstein.h | 347 + .../wasserstein/wasserstein_pure_geom.hpp | 87 + src/rips.h | 478 +- src/ripsb.h | 262 + src/tdautils/.swp | Bin 0 -> 12288 bytes src/tdautils/filtrationDiag.h | 290 +- src/tdautils/filtrationDiagb.h | 148 + src/tdautils/gridUtils.h | 25 +- src/tdautils/ripsD2L2.h | 2 +- src/tdautils/typecastUtils.h | 49 +- 108 files changed, 38098 insertions(+), 509 deletions(-) create mode 100644 src/.DS_Store create mode 100644 src/.diag.cpp.swp create mode 100644 src/.grid.h.swp create mode 100644 src/alphaComplexb.h create mode 100644 src/alphaShapeb.h create mode 100755 src/dionysus/bottleneck/basic_defs_bt.h create mode 100755 src/dionysus/bottleneck/bottleneck.h create mode 100755 src/dionysus/bottleneck/bottleneck_detail.h create mode 100755 src/dionysus/bottleneck/bottleneck_detail.hpp create mode 100755 src/dionysus/bottleneck/bound_match.h create mode 100755 src/dionysus/bottleneck/bound_match.hpp create mode 100755 src/dionysus/bottleneck/def_debug_bt.h create mode 100755 src/dionysus/bottleneck/diagram_reader.h create mode 100755 src/dionysus/bottleneck/diagram_traits.h create mode 100755 src/dionysus/bottleneck/dnn/geometry/euclidean-fixed.h create mode 100755 src/dionysus/bottleneck/dnn/local/kd-tree.h create mode 100755 src/dionysus/bottleneck/dnn/local/kd-tree.hpp create mode 100755 src/dionysus/bottleneck/dnn/local/search-functors.h create mode 100755 src/dionysus/bottleneck/dnn/parallel/tbb.h create mode 100755 src/dionysus/bottleneck/dnn/parallel/utils.h create mode 100755 src/dionysus/bottleneck/dnn/utils.h create mode 100755 src/dionysus/bottleneck/neighb_oracle.h create mode 100755 src/dionysus/wasserstein/auction_oracle.h create mode 100755 src/dionysus/wasserstein/auction_oracle_base.h create mode 100755 src/dionysus/wasserstein/auction_oracle_base.hpp create mode 100755 src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.h create mode 100755 src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.hpp create mode 100755 src/dionysus/wasserstein/auction_oracle_kdtree_restricted.h create mode 100755 src/dionysus/wasserstein/auction_oracle_kdtree_restricted.hpp create mode 100755 src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.h create mode 100755 src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.hpp create mode 100755 src/dionysus/wasserstein/auction_oracle_lazy_heap.h create mode 100755 src/dionysus/wasserstein/auction_oracle_lazy_heap.hpp create mode 100755 src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.h create mode 100755 src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.hpp create mode 100755 src/dionysus/wasserstein/auction_runner_fr.h create mode 100755 src/dionysus/wasserstein/auction_runner_fr.hpp create mode 100755 src/dionysus/wasserstein/auction_runner_gs.h create mode 100755 src/dionysus/wasserstein/auction_runner_gs.hpp create mode 100755 src/dionysus/wasserstein/auction_runner_gs_single_diag.h create mode 100755 src/dionysus/wasserstein/auction_runner_gs_single_diag.hpp create mode 100755 src/dionysus/wasserstein/auction_runner_jac.h create mode 100755 src/dionysus/wasserstein/auction_runner_jac.hpp create mode 100755 src/dionysus/wasserstein/basic_defs_ws.h create mode 100755 src/dionysus/wasserstein/basic_defs_ws.hpp create mode 100755 src/dionysus/wasserstein/catch/catch.hpp create mode 100755 src/dionysus/wasserstein/def_debug_ws.h create mode 100755 src/dionysus/wasserstein/diagonal_heap.h create mode 100755 src/dionysus/wasserstein/diagram_reader.h create mode 100755 src/dionysus/wasserstein/dnn/geometry/euclidean-dynamic.h create mode 100755 src/dionysus/wasserstein/dnn/geometry/euclidean-fixed.h create mode 100755 src/dionysus/wasserstein/dnn/local/kd-tree.h create mode 100755 src/dionysus/wasserstein/dnn/local/kd-tree.hpp create mode 100755 src/dionysus/wasserstein/dnn/local/search-functors.h create mode 100755 src/dionysus/wasserstein/dnn/parallel/tbb.h create mode 100755 src/dionysus/wasserstein/dnn/parallel/utils.h create mode 100755 src/dionysus/wasserstein/dnn/utils.h create mode 100755 src/dionysus/wasserstein/spdlog/async_logger.h create mode 100755 src/dionysus/wasserstein/spdlog/common.h create mode 100755 src/dionysus/wasserstein/spdlog/details/async_log_helper.h create mode 100755 src/dionysus/wasserstein/spdlog/details/async_logger_impl.h create mode 100755 src/dionysus/wasserstein/spdlog/details/file_helper.h create mode 100755 src/dionysus/wasserstein/spdlog/details/log_msg.h create mode 100755 src/dionysus/wasserstein/spdlog/details/logger_impl.h create mode 100755 src/dionysus/wasserstein/spdlog/details/mpmc_bounded_q.h create mode 100755 src/dionysus/wasserstein/spdlog/details/null_mutex.h create mode 100755 src/dionysus/wasserstein/spdlog/details/os.h create mode 100755 src/dionysus/wasserstein/spdlog/details/pattern_formatter_impl.h create mode 100755 src/dionysus/wasserstein/spdlog/details/registry.h create mode 100755 src/dionysus/wasserstein/spdlog/details/spdlog_impl.h create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/format.cc create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/format.h create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.cc create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.h create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/posix.cc create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/posix.h create mode 100755 src/dionysus/wasserstein/spdlog/fmt/bundled/time.h create mode 100755 src/dionysus/wasserstein/spdlog/fmt/fmt.h create mode 100755 src/dionysus/wasserstein/spdlog/fmt/ostr.h create mode 100755 src/dionysus/wasserstein/spdlog/formatter.h create mode 100755 src/dionysus/wasserstein/spdlog/logger.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/android_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/ansicolor_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/base_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/dist_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/file_sinks.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/msvc_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/null_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/ostream_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/stdout_sinks.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/syslog_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/sinks/wincolor_sink.h create mode 100755 src/dionysus/wasserstein/spdlog/spdlog.h create mode 100755 src/dionysus/wasserstein/spdlog/tweakme.h create mode 100755 src/dionysus/wasserstein/wasserstein.h create mode 100755 src/dionysus/wasserstein/wasserstein_pure_geom.hpp create mode 100644 src/ripsb.h create mode 100644 src/tdautils/.swp create mode 100644 src/tdautils/filtrationDiagb.h diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dd491778fb310f4c8e68fc7b51dd6fdf35938287 GIT binary patch literal 6148 zcmeHK%}OId5bjQ*&TK$N5LEEcu!p_mki?+ZA)9y-6i3#>?iy#9(F_bTVKPY;LSXO# zd;q=r^BlXrgb%V0uvOiiWkc4Rh=eMr{;I31s_Acr{$Y&q=D@EpW--PLP{e`>&3^>P zQRk$fJ%}7@^m|^Ec+mUp_2#0<@g5n#Z&zdi`@u4A^zZkJ!hTk*KJ!Acw6wfpm>EQ0o)%1D59+~S17j*Xz==o<0V8Cu<Ph!c{7uO6B^A!Bsl=g^sf|<_c9h<8o%Gqi5#&g~H|R;1@ESaa$p^!~iic z%Ro_eEj<6*FTek1lc+}w5CaRv057$iRuh(_&(^uc;aMv{&p}ZzE?4+B1q^u;E0bh222&iKp9tIyJEn09L()d!6t;$I8ZnS1EB;c-~amVnH}w{ zmY{_0;$wEE-+TS~f9Les<~3ZtmN{x z+-q_T=x~)eYt78EQfYY3hH`1n$zAzU#>_`z{kflDYWb$rMyq0(foFOM;h~by5~w7Guls7A3b{Wq?1b` zCK5;_kVqhrKq7%e0*M3?2_zCoB=CPv0@aae#xAmbkel^m-OmriJ`ZufZ*V`4i{1aA z`+b@F>DdgtbPjcY?{`11cR#)R)7HlvWWf(8O z18_h55N?HA;OlS^Sg;&U1O?OKcONp0SK(jan{Yk+J&eI9!-OJ~b@b7RZ zTnSgeD4Y%JVKH>VC*b4o2XGAh_6WoHIa~tE;f=!$<285)&V?27{r3#xI_QUKaDSU& z?1q&v9X-gz@59Zo4OYQQ zm<7*cAh{WCf*aum_$KUzVK@%{4I|5Cup3M`3?9O;u?uFx0~l#8gAFhZ?s07#O3hYH zF0+_bWox92zeTZXR_z^@%8X_5R)t@ywQ|wQo?=K6e>`HA&BB>>ahp=&yiv7FMK{@{ zWh?lpG^o;cxtdcWCA(PdR9VaPztksWr^{w}%*);iyPC7gs=bf2$11f-yDIT*!>LOL z+UufbtxTyfQmWWaM!cw>san;tnX|HG(}|rX{S2>Wm7O%aXEpd7$k`QTol`UO71dEN zN9}@{S6RDY6)WaNQgQG-}Ytu;sMQ^6^wlrIh0c)_Sa zL{y^yRJEkih@Z4;RV&-26K;@C@28yz%Tfhfk?mrwkVbUmaj0x(9pP7$T~vrPczRt{ zC2*fz@KRi=c37DzLhmSacdD+gu1;0R&g!g-wsP~j9cI2pmi6UF0e@ZAFlB6oNbl%!wc*g%vD6?hz7g?f ztxo=5tW?XjOjVaTkE)Ont4JtuKcXcVsgyFyX`9+c1rEx4rm9u!;*cs>g%UT+iZUsL zA|TPMn0C>VHd42vz%>r#NA@-OwP26Zq$|@8>+TCy3Tl+ z9QF2YBW7q zo^CphWNAWD-MwpOtSk+=;RL5ly-=agNJTK@h7;PU<&nBc>vB6YsS446k>)_apw1bg z4o8|;xcboQ&OFn(Z;g(W%GGV#jM*t8Iuu4;U#b2I_0y@EPDPdBbHB&m=p!dR`o@x( z_}Ckb7-^46O=#q8DuT(5yU}cg>k8#R`BArNUMDf;MziTwQmLYH`9o)H#EnMsD{L$h z16ph2t)Fa@(RMZ!mW1EWBW=(Pj5prSu#1^|Eo-SoW#)rjxh3?S-Y>&hEBb?!dGwcT zwGjQm%#Y;EO*wNU7FGYg+?j1kjL>87hptuad?hvLk9mFK8tvv~#tg}k{-2Z4tUIet zrAyT+)yB#UhyIjksZ{V1e(0TMrNZ>5YT3n9=x^^aPTGE6yl6d}C2RW#!sc0E(3`a6V|L;MUyb1kM^nY*t{3Uexhd}iF(_t3uMQ48q?gP>9FNf7| z2)v3e{vbR6--Gj@2fE>Ccn#hB4yeK(!7=bLm<}I>qv5~N)o*|T{3(d;-UdHKXa5_x z5_Z8^@DX?pz5OZp54a1?hcnkw8ll z@bo{lI7fk7Ycls96;0Yx<>YTqgL7xZP9U|ULo|iuV)hC2L1x-mdHnReR2s!xQ$mFp z#B74%+@rNYnWwanK;IBrAL=&o=j}_I6y-~;uT@2(_7$9}ua+C~Xiq8B0d(RLQfD)G zQ%CXL+Nrj770acaq8am7rj(bz!n4Uvm3QXKGQ9G|wSL7_QJx$gttD`Wr~CR(ZLV}i zG|cYSeNuA%iDPxoPM!7Ju+M% zp}>>TBBK?0rlLZGuT;u08HsVOQYv6?a|}C`5llu$hV{p#GKF)eYEI!d(jc*ArO2^A z5!VLQiD;4WIMh&G-6|L?F#`_du!rPI`K%uA#Zam9Lo3vl%6c2#*6>@Rh!jre5lPH+V(9&O6jph5{Zn-C{HAFEhmkQu{x zYfzreFEXb~d_E%cP~UjvxSceg&PFD!iO&a4SQE1ec+=U&wz`PFK-f8BGPB?Qus>4b z^Wpj>j_+~l+a4Up&BW)Uv8QFz>Cz;|+^EMQr5!~Yn*4c9^$`r%l3 z34Q)<_&XSeLD&KdKy>{#(d~Z#7l8NyoC~tYU=Mo!O>iSz01MzPbbQhMpMZaYi{Vr_ z2n^VZuKy0a4Znt`;STr$TnHz_2jEt8{q1lP^guVvg%lhM;^+T7`u?3D{{4Rei(wi( zfsX%OsKF9=5n22#h)fRQ{xA6LHn9bSWH;3~KTE`kf;0{A>^0r4wH!NKqtwuL+3R=5Q&g^jQQ*26kj3%?Hs!OPeo z?t)7p4fEhoFkmnCo)_V(@DFeuTnpF0rLY7(4W9(LWS$;$pFx$IlJ%Uya;YFII2?%ACx^u=G--D%lFx$3e~t(fgpWgVu!n6*Ybm2*eiTirDw zT(h?5;2pAf>n(DM!svR1s|NFJW+F~XL~*W$f!+1?nUYDJfx6)CeAP?E{VY!~T4_WGXaJxx$DyqUK$VuG3>q|-+d1T=Zovukn?JKe6f*= z<<@(=7NrqU>;0Aepvr4KAifkRmD3aC1loibw60WADK`u3XEA$wLr#XS4@&gF%a=k@l+7p!%5jV*zPYtyDKp}P&I z@-*xi{T?n-p96PWXY9KrBW@Hy%fA(dxKZkEm;QW{VVy2 z6z4MEy?F{sv#o8vm$XG9Z44ok^`>hna?F#?YlBYp7mt=;TNcH%FF(3Ttz^3uTZOjS z_*-*gesxjPkmqI>aRx=8T&}i3=a;qqdqC_0cR&%=!&+Da zvQOYN=z!^9z~k5i{t+$%+2?->^n&;Xd>S6dKJZic3H%s-1oy(#up26{3XX+0u@(Fd zUWNPND)=&730J_E;0#y^Js`G)`>`8b4?E#B=!4~u0@+*e3_K0q9s@Zy-~t#1EiVyN zawQT-B#=lTkw7AWL;{Hf#*si`LzhEm5*vDAL!bIZ2x)TOLec+U8@i~YDyQ|#z&zU6 z=qU>W#%kBv-)w1PYb%8@Y;jgto-O=%4f*Fi`71%;LWnFQ)bDoM+Uc$x1=}*4u50jM z*}v~Ea3F(VMI*eLF&XPZM1t-@lh;k{+hpKnO&5#vc-PuPVTG?;IIYMt{S`xz!M;On z?QB$LiqdYpn+j{e>b@1@7MSS&Da!5w!#9K53=w7FnE#m`ftE6*pP!E z7=!_kb^VuFw-=v)AHjvN7@lKIUe53vgm(B9Ywoh|U(W420%VW>4k*LtU^yswg*Ek` zz}H|5R>LY-333j@baL1(&b$X7hsE zi?zabSq81R>z`e!U+

>0S;HELD|NEY*fMxlsq=>_FK{n{`(AkWAA@xWT#@0Q730 zD%7YodS|cr2)HSS_fE?`Wl1!Au%wsGv`H?N+M-`oa^c-@vQ{Wb+MdywKXrzky}$7! zBcshY2Xw`9;Z`nBklq1T`mh=2P)TnUmm?*cN~PUDA+TL<%jPIR{}>gw#KQol7^Y9jj$^_JoI;}P`~?PO>w8s_^*qc~a& zZ8?`Wmo-9tnxlJCW2|gqJ8#{x(J5s^DTTH)hNaWQ12U?D?cdIH>0wpi2^HD&E(^!e zjzIcve}Wobrgpm4ckTE_W@*c18~emga3pl7@A92fIrqa)#y{E7h#gMhcDmt%7o}>- zNTG8uP-2mF-1R;0%#!#BrAHUZeDlS&$O4_0*H}V z0%^FjFWud*j7Y4Z!$5=0?D~mr4hHD?+>;Vyd%vBr^>G3Xo2{P*;yUs29FW5MhEHA% zwe`BDu$1LjIcs$scRufEg;r7xRo)J%{>Mka5`~=~9Jxz37 zdtd9>GkxEQP!iD@YGWlhBL)W3$aN&NQ$ybV_+=_?YshlOCCJxkr0P^Bpc6^1imUBg5l4 z4+>GsAs9x28=)w5BLR`z!9etX)M?{7(KFHiy?)?3!4f^_BAU^zZPX8Eu7G6L{{}qUz{$IjMI0@##$KXx$_3y#?&^N_JD_|`R2tK_!J!A^J-)*sRm1l+Gvq(U8WkRlURnxlJR8jbX* zzVTGtBJ1;|&yCXWy3+5yr{vVA*Z}F;C_`EB$Glc7K2BIi3iXB*PX^krS^jAW{m&)T zoyZoRfH!*osUEt*qg`vw$Zvv8`rqhe`~Yl3LB*1PYr;K_SK8G3-fy;UW*$DWaNIK_ zo19!}I4yuaAXpd~@4QN;6OKJo+lc)0D}&OK37}XhWP;yu0D8~1jjt|5?)a7kai0}U z1pUbGe6+wkkMZby+L2|f^XUfDjTOjl>243p80EA=Z3m!D2qru$W=Qg6*JT!y18RSTIr|aqg>6S9LsAtOzMc zRnc+R=f3;yJMY$g_uW)y=KAV7eWf`oaC}OLU+-Pref6EsixcOCi2A;3hJ8ozvm=fj zEf-F&o7M&2e{Om+@Tb4rbbQNnO0B+=D{NidZf(g>nt|1Ct#4nuxwEv`3Z^65Y=wc< z>IKSf_DgN6i;97Yfrn+_r1;X=*%{Vydg2s)_A}QWR$KT@R;~6?0qb?Ioq+HNUyWs52V*q>90HXJiV-bDh4VBDh4VBDh4VBDh4VB zDh4VBDh4VBDh4VB9zh07TZkuM*I#*~kN^MA_W%FH&9{L!fjB;8~yzoCN-GQivCU-vcJl1{Q!*z`s5r#GipTfY*Va0N((b zfB~EVJ`Mcs<3jubh=6V23UCQn1kM9(U;&s1o&*H&n~w?M0TTcLcRwn`uYdtSz#AWd zKHx0y-iL+wEpQ7s56lDaJ}$&xfER%u0eiqSaQ8#t2lyp01U7*+U;%gvcoKN~gLn?u zz#eb|cnY|Sg34>aE5IFK7kCzU1~>uSLqX;)@JHZhz;l2C`alop0?R-h;JJG%E7Xj9 ztkohj=%^?#gCW_1MgLD zAJcPMb#2)t3}R(jo= zXS$Ng%!)|j)y7LcKav>r*%;j%v$^m@Peeo1TcS?SAQGAGuJZQMIFf|{qSX*nlf;B% zj--7aEsi+&)d}PPQ3N3ZQF33BHMAVWvOq?0;3W!*UGi6UI;JIB7XwN8EyS*l3J?Pa zKIW@|*);QuG0Pp*(aS_?+a*L!8PLQM0?sgu!wK?nwRw4QXJV{j$wP>;hEGQlVZJZz zz)ae9^0>zujyQJuDkOO;Hl2_RxRY`bZ${w}W8xscr3PUZg(hj5X2;Ab`XM@7HxKUD zfv*-h+$emeNL^49gLYmfIv^eKHpVj0JlfXVAW4#|2qt^6+W}>;k7Hj#d5CqDcoE9& zA2Tyu0x~)lDTx{wI2M6svx&9Ao`j>7yexkal4f}$bnQs5hwj|jIvUqqo6Q+p^~m3c znL0k_Y7v?DXv-Q5=H}KA0JYvla&Ri@=p0HAv>mxGomLnlO6pV26+(Vyj=GS%g6s8D z_hKeoKt_9DC|dP8-RJ}gL35SYz^-XYP9SItl-m*u9|yW1=Jgu0TnNaZ(Cpzw<5c4j z%W9GwBKezT6igv=j6k-l5J(8i`W{U!AUophCOL`%SkytdyL*=E#0~Z`F-`!=g6)!E zqPbSnt@wgYEszHd*=oe}8ZmOe9>PZ=wLmlHlCya&dXA>1xYMjj)*n)h)MRaQHbj2B zt5(>}rxG=F^7xVmk+m%h;@B39HPAgJHtSOu&A-}eK4G(AU}9bmH!{`_C3Q?WPtMVN zj24G%aE9jh9Cymyp*_p2n^-^(3~d``F_odjSjgx_*0ikHjd}SoY~?mPvZs*5A~xd! z*%F`P)z_jBixg5>8IH0Dopa`6MI-H9JkETICw1el`hV&J~_$H z*yD~(4opDjc$t$WyTDBDa!f?YK(;E)pwatTBvR-Z$B0B^KTK?cawWBd#>_*LK%rYC z9is)brU4{Nx;n$g$TUd6*v5Uy!-b*oz$rmK&O$O!u)56~ozu#*oYT(OsZomLj_ST2%*w&4B$PqD|{!`_$o z|5H2!94}!%e+%$|E5N_8e}4t=fiD2BVbA_O;4*Lmcolo|9|O+=b>MTrKd|?H9k>I8 zzy@Fd_pk?l3-~#(1>D0P{5{}5z~6zFfjhtp!1F*GFn}|_yRh?3;QPQ$fbG5kR7b@? z#X!YC#X!YC#lR!NK(=AvecVzH*%_a#`6!_A$wjVg=prX4f29&|xhC2`YNyK+I7!PL z(7?s8tk%+?87?7>O2(XLaYb0#^JC>719C{4xP*)y{`+k+d0)#r(6kJkCa7g2bJOz; zc|&JnNO#BWFI7qIvt18`?dLqMSj)mKoy;fM+h`Y%)(|;iFFh|8x)N(ysMT6V)xAIu zGVUysE-84xi##iDqfp8JP=m}S6y1$k_5Mo7^dwr?KaQg$FR+dbVPDfK4zHU&P{p+D zUj=RV7guwqEJ*1=7b>c}JxP4{NTIT$6hz_dR6dF&uHtCnex*|D0~J}1EU)s*bm$(b t=x<{?qGVTAxH)E4;fb!MQoAarqJmtS|EiqIVVHwrRZiui?f;^j`d^8zT@L^N literal 0 HcmV?d00001 diff --git a/src/alphaComplex.h b/src/alphaComplex.h index 395180d..e5c7206 100644 --- a/src/alphaComplex.h +++ b/src/alphaComplex.h @@ -8,8 +8,7 @@ #include // for Dionysus -#include - +#include // for phat #include @@ -62,18 +61,12 @@ void alphaComplexDiag( FiltrationDiagGudhi( alphaCmplx, coeff_field_characteristic, min_persistence, 2, printProgress, persDgm); - } - else if (libraryDiag[0] == 'D' && libraryDiag[2] == '2') { - Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(alphaCmplx); - FiltrationDiagDionysus2< Persistence2 >( - filtration, maxdimension, location, printProgress, persDgm, persLoc, - persCycle); } else if (libraryDiag[0] == 'D') { // 2018-08-04 // switching back to original code - Fltr filtration = filtrationGudhiToDionysus< Fltr >(alphaCmplx); - FiltrationDiagDionysus< Persistence >( + Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(alphaCmplx); + FiltrationDiagDionysus2< Persistence2 >( filtration, maxdimension, location, printProgress, persDgm, persLoc, persCycle); } diff --git a/src/alphaComplexb.h b/src/alphaComplexb.h new file mode 100644 index 0000000..35aa670 --- /dev/null +++ b/src/alphaComplexb.h @@ -0,0 +1,93 @@ +#include +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include +#include + +// for phat +#include + +#include + + + +// AlphaComplexDiag +/** \brief Interface for R code, construct the persistence diagram of the alpha + * complex constructed on the input set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X An nx3 matrix of coordinates, + * @param[in] maxalphasquare Threshold for the Alpha complex, + * @param[in] printProgress Is progress printed? + */ +template< typename RealMatrix, typename Print > +void alphaComplexDiag( + const RealMatrix & X, //points to some memory space + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const std::string & libraryDiag, + const bool location, + const bool printProgress, + const Print & print, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + using Kernel = CGAL::Epick_d< CGAL::Dynamic_dimension_tag>; + using Point = Kernel::Point_d; + + int coeff_field_characteristic = 2; + + float min_persistence = 0.0; + + Gudhi::Simplex_tree<> alphaCmplx = + AlphaComplexFiltrationGudhi< Gudhi::Simplex_tree<> >( + X, printProgress, print); + + // 2018-08-04 + // switching back to original code + + // Compute the persistence diagram of the complex + if (libraryDiag[0] == 'G') { + // 2018-08-04 + // switching back to original code + FiltrationDiagGudhi( + alphaCmplx, coeff_field_characteristic, min_persistence, 2, + printProgress, persDgm); + } + else if (libraryDiag[0] == 'D' && libraryDiag[2] == '2') { + Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(alphaCmplx); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else if (libraryDiag[0] == 'D') { + // 2018-08-04 + // switching back to original code + Fltr filtration = filtrationGudhiToDionysus< Fltr >(alphaCmplx); + FiltrationDiagDionysus< Persistence >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else { + // 2018-08-04 + // switching back to original code + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationGudhiToPhat< phat::column, phat::dimension >( + alphaCmplx, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } +} diff --git a/src/alphaShape.h b/src/alphaShape.h index 85bc08d..9063923 100644 --- a/src/alphaShape.h +++ b/src/alphaShape.h @@ -1,80 +1,74 @@ -#include -#include - -// for changing formats and typecasting -#include - -//for GUDHI -#include - -// for Dionysus -#include - -// for phat -#include - -#include - - - -// AlphaShapeDiag -/** \brief Interface for R code, construct the persistence diagram of the alpha - * shape complex constructed on the input set of points. - * - * @param[out] Rcpp::List A list - * @param[in] X An nx3 matrix of coordinates, - * @param[in] printProgress Is progress printed? - */ -template< typename RealMatrix, typename Print > -void alphaShapeDiag( - const RealMatrix & X, //points to some memory space - const unsigned nSample, - const unsigned nDim, - const int maxdimension, - const std::string & libraryDiag, - const bool location, - const bool printProgress, - const Print & print, - std::vector< std::vector< std::vector< double > > > & persDgm, - std::vector< std::vector< std::vector< unsigned > > > & persLoc, - std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle, - RealMatrix & coordinates -) { - - int coeff_field_characteristic = 2; - - float min_persistence = 0.0; - - Gudhi::Simplex_tree<> smplxTree = - AlphaShapeFiltrationGudhi< Gudhi::Simplex_tree<> >( - X, printProgress, print, coordinates); - - // Compute the persistence diagram of the complex - if (libraryDiag[0] == 'G') { - FiltrationDiagGudhi( - smplxTree, coeff_field_characteristic, min_persistence, 2, - printProgress, persDgm); - } - else if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { - Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(smplxTree); - FiltrationDiagDionysus2< Persistence2 >( - filtration, maxdimension, location, printProgress, persDgm, persLoc, - persCycle); - } - else if (libraryDiag[0] == 'D') { - Fltr filtration = filtrationGudhiToDionysus< Fltr >(smplxTree); - FiltrationDiagDionysus< Persistence >( - filtration, maxdimension, location, printProgress, persDgm, persLoc, - persCycle); - } - else { - std::vector< phat::column > cmplx; - std::vector< double > values; - phat::boundary_matrix< phat::vector_vector > boundary_matrix; - filtrationGudhiToPhat< phat::column, phat::dimension >( - smplxTree, cmplx, values, boundary_matrix); - FiltrationDiagPhat( - cmplx, values, boundary_matrix, maxdimension, location, - printProgress, persDgm, persLoc, persCycle); - } -} +#include +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include + +// for phat +#include + +#include + + + +// AlphaShapeDiag +/** \brief Interface for R code, construct the persistence diagram of the alpha + * shape complex constructed on the input set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X An nx3 matrix of coordinates, + * @param[in] printProgress Is progress printed? + */ +template< typename RealMatrix, typename Print > +void alphaShapeDiag( + const RealMatrix & X, //points to some memory space + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const std::string & libraryDiag, + const bool location, + const bool printProgress, + const Print & print, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle, + RealMatrix & coordinates +) { + + int coeff_field_characteristic = 2; + + float min_persistence = 0.0; + + Gudhi::Simplex_tree<> smplxTree = + AlphaShapeFiltrationGudhi< Gudhi::Simplex_tree<> >( + X, printProgress, print, coordinates); + + // Compute the persistence diagram of the complex + if (libraryDiag[0] == 'G') { + FiltrationDiagGudhi( + smplxTree, coeff_field_characteristic, min_persistence, 2, + printProgress, persDgm); + } + else if (libraryDiag[0] == 'D') { + Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(smplxTree); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationGudhiToPhat< phat::column, phat::dimension >( + smplxTree, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } +} diff --git a/src/alphaShapeb.h b/src/alphaShapeb.h new file mode 100644 index 0000000..04cb5fc --- /dev/null +++ b/src/alphaShapeb.h @@ -0,0 +1,81 @@ +#include +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include +#include + +// for phat +#include + +#include + + + +// AlphaShapeDiag +/** \brief Interface for R code, construct the persistence diagram of the alpha + * shape complex constructed on the input set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X An nx3 matrix of coordinates, + * @param[in] printProgress Is progress printed? + */ +template< typename RealMatrix, typename Print > +void alphaShapeDiag( + const RealMatrix & X, //points to some memory space + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const std::string & libraryDiag, + const bool location, + const bool printProgress, + const Print & print, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle, + RealMatrix & coordinates +) { + + int coeff_field_characteristic = 2; + + float min_persistence = 0.0; + + Gudhi::Simplex_tree<> smplxTree = + AlphaShapeFiltrationGudhi< Gudhi::Simplex_tree<> >( + X, printProgress, print, coordinates); + + // Compute the persistence diagram of the complex + if (libraryDiag[0] == 'G') { + FiltrationDiagGudhi( + smplxTree, coeff_field_characteristic, min_persistence, 2, + printProgress, persDgm); + } + else if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { + Fltr2 filtration = filtrationGudhiToDionysus2< Fltr2 >(smplxTree); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else if (libraryDiag[0] == 'D') { + Fltr filtration = filtrationGudhiToDionysus< Fltr >(smplxTree); + FiltrationDiagDionysus< Persistence >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationGudhiToPhat< phat::column, phat::dimension >( + smplxTree, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } +} diff --git a/src/diag.cpp b/src/diag.cpp index 37c598f..73c0fec 100644 --- a/src/diag.cpp +++ b/src/diag.cpp @@ -6,9 +6,9 @@ #include // for Rips -#include +//#include #include -#include +//#include #include // for grid @@ -23,6 +23,9 @@ // for Dionysus #include #include +//#include +#include + // for phat #include @@ -144,8 +147,11 @@ double Bottleneck(const Rcpp::NumericMatrix & Diag1 , const Rcpp::NumericMatrix & Diag2 ) { - return bottleneck_distance(RcppToDionysus< PersistenceDiagram<> >(Diag1), - RcppToDionysus< PersistenceDiagram<> >(Diag2)); + //return bottleneck_distance(RcppToDionysus< PersistenceDiagram<> >(Diag1), + // RcppToDionysus< PersistenceDiagram<> >(Diag2)); + auto b = RcppToPairVector< std::vector > >(Diag1); + auto a = RcppToPairVector< std::vector > >(Diag2); + return hera::bottleneckDistExact(a,b); } @@ -158,6 +164,13 @@ Wasserstein(const Rcpp::NumericMatrix & Diag1 ) { return wasserstein_distance(RcppToDionysus< PersistenceDiagram<> >(Diag1), RcppToDionysus< PersistenceDiagram<> >(Diag2), p); + //hera::AuctionParams params; + //params.wasserstein_power = p; + //return hera::wasserstein_dist(RcppToDionysus2< PDgm >(Diag1), + // RcppToDionysus2< PDgm >(Diag2), params); + //auto b = RcppToPairVector< std::vector > >(Diag1); + //auto a = RcppToPairVector< std::vector > >(Diag2); + //return hera::wasserstein_dist(a, b, params); } diff --git a/src/dionysus/bottleneck/basic_defs_bt.h b/src/dionysus/bottleneck/basic_defs_bt.h new file mode 100755 index 0000000..69dc709 --- /dev/null +++ b/src/dionysus/bottleneck/basic_defs_bt.h @@ -0,0 +1,476 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_BASIC_DEFS_BT_H +#define HERA_BASIC_DEFS_BT_H + +#ifdef _WIN32 +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "def_debug_bt.h" + +#ifndef FOR_R_TDA +#include +#endif + +namespace hera { +namespace bt { + +typedef int IdType; +constexpr IdType MinValidId = 10; + +template +struct Point { + Real x, y; + + bool operator==(const Point &other) const + { + return ((this->x == other.x) and (this->y == other.y)); + } + + bool operator!=(const Point &other) const + { + return !(*this == other); + } + + Point(Real ax, Real ay) : x(ax), y(ay) {} + + Point() : x(0.0), y(0.0) {} + +#ifndef FOR_R_TDA + + template + friend std::ostream& operator<<(std::ostream& output, const Point& p) + { + output << "(" << p.x << ", " << p.y << ")"; + return output; + } + +#endif +}; + +template +struct DiagramPoint { + // Points above the diagonal have type NORMAL + // Projections onto the diagonal have type DIAG + // for DIAG points only x-coordinate is relevant + // to-do: add getters/setters, checks in constructors, etc + enum Type { + NORMAL, DIAG + }; + // data members +private: + Real x, y; +public: + Type type; + IdType id; + + // operators, constructors + bool operator==(const DiagramPoint &other) const + { + // compare by id only + assert(this->id >= MinValidId); + assert(other.id >= MinValidId); + bool areEqual{ this->id == other.id }; + assert(!areEqual or ((this->x == other.x) and (this->y == other.y) and (this->type == other.type))); + return areEqual; + } + + bool operator!=(const DiagramPoint &other) const + { + return !(*this == other); + } + + DiagramPoint(Real _x, Real _y, Type _type, IdType _id) : + x(_x), + y(_y), + type(_type), + id(_id) + { + if ( _y == _x and _type != DIAG) + throw std::runtime_error("Point on the main diagonal must have DIAG type"); + + } + + + bool isDiagonal() const { return type == DIAG; } + + bool isNormal() const { return type == NORMAL; } + + bool isInfinity() const + { + return x == std::numeric_limits::infinity() or + x == -std::numeric_limits::infinity() or + y == std::numeric_limits::infinity() or + y == -std::numeric_limits::infinity(); + } + + Real inline getRealX() const // return the x-coord + { + return x; + } + + Real inline getRealY() const // return the y-coord + { + return y; + } + +#ifndef FOR_R_TDA + template + friend std::ostream& operator<<(std::ostream& output, const DiagramPoint& p) + { + if ( p.isDiagonal() ) { + output << "(" << p.x << ", " << p.y << ", " << 0.5 * (p.x + p.y) << ", " << p.id << " DIAG )"; + } else { + output << "(" << p.x << ", " << p.y << ", " << p.id << " NORMAL)"; + } + return output; + } +#endif + +}; + +// compute l-inf distance between two diagram points +template +Real distLInf(const DiagramPoint& a, const DiagramPoint& b) +{ + if ( a.isDiagonal() and b.isDiagonal() ) { + // distance between points on the diagonal is 0 + return 0.0; + } + // otherwise distance is a usual l-inf distance + return std::max(fabs(a.getRealX() - b.getRealX()), fabs(a.getRealY() - b.getRealY())); +} + + +template +inline void hash_combine(std::size_t & seed, const T & v) +{ + std::hash hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +template +struct DiagramPointHash { + size_t operator()(const DiagramPoint &p) const + { + assert(p.id >= MinValidId); + return std::hash()(p.id); + } +}; + +template +Real distLInf(const DiagramPoint &a, const DiagramPoint &b); + +//template +//typedef std::unordered_set PointSet; +template +class DiagramPointSet; + +template +void addProjections(DiagramPointSet &A, DiagramPointSet &B); + +template +class DiagramPointSet { +public: + + using Real = Real_; + using DgmPoint = DiagramPoint; + using DgmPointHash = DiagramPointHash; + using const_iterator = typename std::unordered_set::const_iterator; + using iterator = typename std::unordered_set::iterator; + +private: + + bool isLinked { false }; + IdType maxId { MinValidId + 1 }; + std::unordered_set points; + +public: + + void insert(const DgmPoint& p) + { + points.insert(p); + if (p.id > maxId) { + maxId = p.id + 1; + } + } + + void erase(const DgmPoint& p, bool doCheck = true) + { + // if doCheck, erasing non-existing elements causes assert + auto it = points.find(p); + if (it != points.end()) { + points.erase(it); + } else { + assert(!doCheck); + } + } + + + void erase(const const_iterator it) + { + points.erase(it); + } + + void removeDiagonalPoints() + { + if (isLinked) { + auto ptIter = points.begin(); + while(ptIter != points.end()) { + if (ptIter->isDiagonal()) { + ptIter = points.erase(ptIter); + } else { + ptIter++; + } + } + isLinked = false; + } + } + + size_t size() const + { + return points.size(); + } + + void reserve(const size_t newSize) + { + points.reserve(newSize); + } + + void clear() + { + points.clear(); + } + + bool empty() const + { + return points.empty(); + } + + bool hasElement(const DgmPoint &p) const + { + return points.find(p) != points.end(); + } + + iterator find(const DgmPoint &p) + { + return points.find(p); + } + + iterator begin() + { + return points.begin(); + } + + iterator end() + { + return points.end(); + } + + const_iterator cbegin() const + { + return points.cbegin(); + } + + const_iterator cend() const + { + return points.cend(); + } + + + const_iterator find(const DgmPoint &p) const + { + return points.find(p); + } + +#ifndef FOR_R_TDA + template + friend std::ostream& operator<<(std::ostream& output, const DiagramPointSet& ps) + { + output << "{ "; + for(auto pit = ps.cbegin(); pit != ps.cend(); ++pit) { + output << *pit << ", "; + } + output << "\b\b }"; + return output; + } +#endif + + friend void addProjections(DiagramPointSet& A, DiagramPointSet& B); + + template + void fillIn(PairIterator begin_iter, PairIterator end_iter) + { + isLinked = false; + clear(); + IdType uniqueId = MinValidId + 1; + for (auto iter = begin_iter; iter != end_iter; ++iter) { + insert(DgmPoint(iter->first, iter->second, DgmPoint::NORMAL, uniqueId++)); + } + } + + template + void fillIn(const PointContainer& dgm_cont) + { + using Traits = DiagramTraits; + isLinked = false; + clear(); + IdType uniqueId = MinValidId + 1; + for (const auto& pt : dgm_cont) { + Real x = Traits::get_x(pt); + Real y = Traits::get_y(pt); + insert(DgmPoint(x, y, DgmPoint::NORMAL, uniqueId++)); + } + } + + + // ctor from range + template + DiagramPointSet(PairIterator begin_iter, PairIterator end_iter) + { + fillIn(begin_iter, end_iter); + } + + // ctor from container, uses DiagramTraits + template + DiagramPointSet(const PointContainer& dgm) + { + fillIn(dgm); + } + + + // default ctor, empty diagram + DiagramPointSet(IdType minId = MinValidId + 1) : maxId(minId + 1) {}; + + IdType nextId() { return maxId + 1; } + +}; // DiagramPointSet + + +template +Real getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B) { + Real result{0.0}; + DiagramPoint begA = *(A.begin()); + DiagramPoint optB = *(B.begin()); + for (const auto &pointB : B) { + if (distLInf(begA, pointB) > result) { + result = distLInf(begA, pointB); + optB = pointB; + } + } + for (const auto &pointA : A) { + if (distLInf(pointA, optB) > result) { + result = distLInf(pointA, optB); + } + } + return result; +} + +// preprocess diagrams A and B by adding projections onto diagonal of points of +// A to B and vice versa. Also removes points at infinity! +// NB: ids of points will be changed! +template +void addProjections(DiagramPointSet& A, DiagramPointSet& B) +{ + + using Real = Real_; + using DgmPoint = DiagramPoint; + using DgmPointSet = DiagramPointSet; + + IdType uniqueId {MinValidId + 1}; + DgmPointSet newA, newB; + + // copy normal points from A to newA + // add projections to newB + for(auto& pA : A) { + if (pA.isNormal() and not pA.isInfinity()) { + // add pA's projection to B + DgmPoint dpA {pA.getRealX(), pA.getRealY(), DgmPoint::NORMAL, uniqueId++}; + DgmPoint dpB {(pA.getRealX() +pA.getRealY())/2, (pA.getRealX() +pA.getRealY())/2, DgmPoint::DIAG, uniqueId++}; + newA.insert(dpA); + newB.insert(dpB); + } + } + + for(auto& pB : B) { + if (pB.isNormal() and not pB.isInfinity()) { + // add pB's projection to A + DgmPoint dpB {pB.getRealX(), pB.getRealY(), DgmPoint::NORMAL, uniqueId++}; + DgmPoint dpA {(pB.getRealX() +pB.getRealY())/2, (pB.getRealX() +pB.getRealY())/2, DgmPoint::DIAG, uniqueId++}; + newB.insert(dpB); + newA.insert(dpA); + } + } + + A = newA; + B = newB; + A.isLinked = true; + B.isLinked = true; +} + + +//#ifndef FOR_R_TDA + +//template +//std::ostream& operator<<(std::ostream& output, const DiagramPoint& p) +//{ +// if ( p.isDiagonal() ) { +// output << "(" << p.x << ", " << p.y << ", " << 0.5 * (p.x + p.y) << ", " << p.id << " DIAG )"; +// } else { +// output << "(" << p.x << ", " << p.y << ", " << p.id << " NORMAL)"; +// } +// return output; +//} + +//template +//std::ostream& operator<<(std::ostream& output, const DiagramPointSet& ps) +//{ +// output << "{ "; +// for(auto pit = ps.cbegin(); pit != ps.cend(); ++pit) { +// output << *pit << ", "; +// } +// output << "\b\b }"; +// return output; +//} +//#endif // FOR_R_TDA + + +} // end namespace bt +} // end namespace hera +#endif // HERA_BASIC_DEFS_BT_H diff --git a/src/dionysus/bottleneck/bottleneck.h b/src/dionysus/bottleneck/bottleneck.h new file mode 100755 index 0000000..0d4e1ed --- /dev/null +++ b/src/dionysus/bottleneck/bottleneck.h @@ -0,0 +1,118 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_BOTTLENECK_H +#define HERA_BOTTLENECK_H + + +#include +#include +#include +#include +#include + +#include "diagram_traits.h" +#include "diagram_reader.h" +#include "bottleneck_detail.h" +#include "basic_defs_bt.h" +#include "bound_match.h" + +namespace hera { + +// functions taking containers as input +// template parameter PairContainer must be a container of pairs of real +// numbers (pair.first = x-coordinate, pair.second = y-coordinate) +// PairContainer class must support iteration of the form +// for(it = pairContainer.begin(); it != pairContainer.end(); ++it) + +// all functions in this header are wrappers around +// functions from hera::bt namespace + +// get exact bottleneck distance, +template +typename DiagramTraits::RealType +bottleneckDistExact(PairContainer& dgm_A, PairContainer& dgm_B) +{ + using Real = typename DiagramTraits::RealType; + hera::bt::DiagramPointSet a(dgm_A); + hera::bt::DiagramPointSet b(dgm_B); + return hera::bt::bottleneckDistExact(a, b, 14); +} + +// get exact bottleneck distance, +template +typename DiagramTraits::RealType +bottleneckDistExact(PairContainer& dgm_A, PairContainer& dgm_B, const int decPrecision) +{ + using Real = typename DiagramTraits::RealType; + hera::bt::DiagramPointSet a(dgm_A); + hera::bt::DiagramPointSet b(dgm_B); + return hera::bt::bottleneckDistExact(a, b, decPrecision); +} + +// return the interval (distMin, distMax) such that: +// a) actual bottleneck distance between A and B is contained in the interval +// b) if the interval is not (0,0), then (distMax - distMin) / distMin < delta +template +std::pair::RealType, typename DiagramTraits::RealType> +bottleneckDistApproxInterval(PairContainer& dgm_A, PairContainer& dgm_B, const typename DiagramTraits::RealType delta) +{ + using Real = typename DiagramTraits::RealType; + hera::bt::DiagramPointSet a(dgm_A); + hera::bt::DiagramPointSet b(dgm_B); + return hera::bt::bottleneckDistApproxInterval(a, b, delta); +} + +// use sampling heuristic: discard most of the points with small persistency +// to get a good initial approximation of the bottleneck distance +template +typename DiagramTraits::RealType +bottleneckDistApproxHeur(PairContainer& dgm_A, PairContainer& dgm_B, const typename DiagramTraits::RealType delta) +{ + using Real = typename DiagramTraits::RealType; + hera::bt::DiagramPointSet a(dgm_A); + hera::bt::DiagramPointSet b(dgm_B); + std::pair resPair = hera::bt::bottleneckDistApproxIntervalHeur(a, b, delta); + return resPair.second; +} + +// get approximate distance, +// see bottleneckDistApproxInterval +template +typename DiagramTraits::RealType +bottleneckDistApprox(PairContainer& A, PairContainer& B, const typename DiagramTraits::RealType delta) +{ + using Real = typename DiagramTraits::RealType; + hera::bt::DiagramPointSet a(A.begin(), A.end()); + hera::bt::DiagramPointSet b(B.begin(), B.end()); + return hera::bt::bottleneckDistApprox(a, b, delta); +} + +} // end namespace hera + +#endif diff --git a/src/dionysus/bottleneck/bottleneck_detail.h b/src/dionysus/bottleneck/bottleneck_detail.h new file mode 100755 index 0000000..27c3c5d --- /dev/null +++ b/src/dionysus/bottleneck/bottleneck_detail.h @@ -0,0 +1,85 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_BOTTLENECK_DETAIL_H +#define HERA_BOTTLENECK_DETAIL_H + + +#include +#include +#include +#include +#include + +#include "diagram_traits.h" +#include "basic_defs_bt.h" +#include "bound_match.h" + +namespace hera { + + +namespace bt { + + + +// functions taking DiagramPointSet as input. +// ATTENTION: parameters A and B (diagrams) will be changed after the call +// (projections added). + +// return the interval (distMin, distMax) such that: +// a) actual bottleneck distance between A and B is contained in the interval +// b) if the interval is not (0,0), then (distMax - distMin) / distMin < epsilon +template +std::pair bottleneckDistApproxInterval(DiagramPointSet& A, DiagramPointSet& B, const Real epsilon); + + +// heuristic (sample diagram to estimate the distance) +template +std::pair bottleneckDistApproxIntervalHeur(DiagramPointSet& A, DiagramPointSet& B, const Real epsilon); + +// get approximate distance, +// see bottleneckDistApproxInterval +template +Real bottleneckDistApprox(DiagramPointSet& A, DiagramPointSet& B, const Real epsilon); + +// get exact bottleneck distance, +template +Real bottleneckDistExact(DiagramPointSet& A, DiagramPointSet& B, const int decPrecision); + +// get exact bottleneck distance, +template +Real bottleneckDistExact(DiagramPointSet& A, DiagramPointSet& B); + +} // end namespace bt + + +} // end namespace hera + +#include "bottleneck_detail.hpp" + +#endif // HERA_BOTTLENECK_DETAIL_H diff --git a/src/dionysus/bottleneck/bottleneck_detail.hpp b/src/dionysus/bottleneck/bottleneck_detail.hpp new file mode 100755 index 0000000..24cb725 --- /dev/null +++ b/src/dionysus/bottleneck/bottleneck_detail.hpp @@ -0,0 +1,507 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_BOTTLENECK_HPP +#define HERA_BOTTLENECK_HPP + +#ifdef FOR_R_TDA +#undef DEBUG_BOUND_MATCH +#undef DEBUG_MATCHING +#undef VERBOSE_BOTTLENECK +#endif + + +#include +#include +#include +#include + +#include "bottleneck_detail.h" + +namespace hera { +namespace bt { + +template +void binarySearch(const Real epsilon, + std::pair& result, + BoundMatchOracle& oracle, + const Real infinityCost, + bool isResultInitializedCorrectly, + const Real distProbeInit) +{ + // aliases for result components + Real& distMin = result.first; + Real& distMax = result.second; + + distMin = std::max(distMin, infinityCost); + distMax = std::max(distMax, infinityCost); + + Real distProbe; + + if (not isResultInitializedCorrectly) { + distProbe = distProbeInit; + if (oracle.isMatchLess(distProbe)) { + // distProbe is an upper bound, + // find lower bound with binary search + do { + distMax = distProbe; + distProbe /= 2.0; + } while (oracle.isMatchLess(distProbe)); + distMin = distProbe; + } else { + // distProbe is a lower bound, + // find upper bound with exponential search + do { + distMin = distProbe; + distProbe *= 2.0; + } while (!oracle.isMatchLess(distProbe)); + distMax = distProbe; + } + } + // bounds are correct , perform binary search + distProbe = ( distMin + distMax ) / 2.0; + while (( distMax - distMin ) / distMin >= epsilon ) { + + if (distMax < infinityCost) { + distMin = infinityCost; + distMax = infinityCost; + break; + } + + if (oracle.isMatchLess(distProbe)) { + distMax = distProbe; + } else { + distMin = distProbe; + } + + distProbe = ( distMin + distMax ) / 2.0; + } + + distMin = std::max(distMin, infinityCost); + distMax = std::max(distMax, infinityCost); +} + +template +inline Real getOneDimensionalCost(std::vector &set_A, std::vector &set_B) +{ + if (set_A.size() != set_B.size()) { + return std::numeric_limits::infinity(); + } + + if (set_A.empty()) { + return Real(0.0); + } + + std::sort(set_A.begin(), set_A.end()); + std::sort(set_B.begin(), set_B.end()); + + Real result = 0.0; + for(size_t i = 0; i < set_A.size(); ++i) { + result = std::max(result, (std::fabs(set_A[i] - set_B[i]))); + } + + return result; +} + +template +inline Real getInfinityCost(const DiagramPointSet &A, const DiagramPointSet &B) +{ + std::vector x_plus_A, x_minus_A, y_plus_A, y_minus_A; + std::vector x_plus_B, x_minus_B, y_plus_B, y_minus_B; + + for(auto iter_A = A.cbegin(); iter_A != A.cend(); ++iter_A) { + Real x = iter_A->getRealX(); + Real y = iter_A->getRealY(); + if ( x == std::numeric_limits::infinity()) { + y_plus_A.push_back(y); + } else if (x == -std::numeric_limits::infinity()) { + y_minus_A.push_back(y); + } else if (y == std::numeric_limits::infinity()) { + x_plus_A.push_back(x); + } else if (y == -std::numeric_limits::infinity()) { + x_minus_A.push_back(x); + } + } + + for(auto iter_B = B.cbegin(); iter_B != B.cend(); ++iter_B) { + Real x = iter_B->getRealX(); + Real y = iter_B->getRealY(); + if (x == std::numeric_limits::infinity()) { + y_plus_B.push_back(y); + } else if (x == -std::numeric_limits::infinity()) { + y_minus_B.push_back(y); + } else if (y == std::numeric_limits::infinity()) { + x_plus_B.push_back(x); + } else if (y == -std::numeric_limits::infinity()) { + x_minus_B.push_back(x); + } + } + + Real infinity_cost = getOneDimensionalCost(x_plus_A, x_plus_B); + infinity_cost = std::max(infinity_cost, getOneDimensionalCost(x_minus_A, x_minus_B)); + infinity_cost = std::max(infinity_cost, getOneDimensionalCost(y_plus_A, y_plus_B)); + infinity_cost = std::max(infinity_cost, getOneDimensionalCost(y_minus_A, y_minus_B)); + + return infinity_cost; +} + +// return the interval (distMin, distMax) such that: +// a) actual bottleneck distance between A and B is contained in the interval +// b) if the interval is not (0,0), then (distMax - distMin) / distMin < epsilon +template +inline std::pair bottleneckDistApproxInterval(DiagramPointSet& A, DiagramPointSet& B, const Real epsilon) +{ + // empty diagrams are not considered as error + if (A.empty() and B.empty()) + return std::make_pair(0.0, 0.0); + + Real infinity_cost = getInfinityCost(A, B); + if (infinity_cost == std::numeric_limits::infinity()) + return std::make_pair(infinity_cost, infinity_cost); + + // link diagrams A and B by adding projections + addProjections(A, B); + + // TODO: think about that! + // we need one threshold for checking if the distance is 0, + // another one for the oracle! + constexpr Real epsThreshold { 1.0e-10 }; + std::pair result { 0.0, 0.0 }; + bool useRangeSearch { true }; + // construct an oracle + BoundMatchOracle oracle(A, B, epsThreshold, useRangeSearch); + // check for distance = 0 + if (oracle.isMatchLess(2*epsThreshold)) { + return result; + } + // get a 3-approximation of maximal distance between A and B + // as a starting value for probe distance + Real distProbe { getFurthestDistance3Approx>(A, B) }; + binarySearch(epsilon, result, oracle, infinity_cost, false, distProbe); + return result; +} + +template +void sampleDiagramForHeur(const DiagramPointSet& dgmIn, DiagramPointSet& dgmOut) +{ + struct pair_hash { + std::size_t operator()(const std::pair p) const + { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + std::unordered_map, int, pair_hash> m; + for(auto ptIter = dgmIn.cbegin(); ptIter != dgmIn.cend(); ++ptIter) { + if (ptIter->isNormal() and not ptIter->isInfinity()) { + m[std::make_pair(ptIter->getRealX(), ptIter->getRealY())]++; + } + } + if (m.size() < 2) { + dgmOut = dgmIn; + return; + } + std::vector v; + for(const auto& ptQtyPair : m) { + v.push_back(ptQtyPair.second); + } + std::sort(v.begin(), v.end()); + int maxLeap = v[1] - v[0]; + int cutVal = v[0]; + for(int i = 1; i < static_cast(v.size())- 1; ++i) { + int currLeap = v[i+1] - v[i]; + if (currLeap > maxLeap) { + maxLeap = currLeap; + cutVal = v[i]; + } + } + std::vector> vv; + // keep points whose multiplicites are at most cutVal + // quick-and-dirty: fill in vv with copies of each point + // to construct DiagramPointSet from it later + for(const auto& ptQty : m) { + if (ptQty.second < cutVal) { + for(int i = 0; i < ptQty.second; ++i) { + vv.push_back(std::make_pair(ptQty.first.first, ptQty.first.second)); + } + } + } + dgmOut.clear(); + dgmOut = DiagramPointSet(vv.begin(), vv.end()); +} + + +// return the interval (distMin, distMax) such that: +// a) actual bottleneck distance between A and B is contained in the interval +// b) if the interval is not (0,0), then (distMax - distMin) / distMin < epsilon +template +std::pair bottleneckDistApproxIntervalWithInitial(DiagramPointSet& A, DiagramPointSet& B, + const Real epsilon, + const std::pair initialGuess, + const Real infinity_cost) +{ + // empty diagrams are not considered as error + if (A.empty() and B.empty()) + return std::make_pair(0.0, 0.0); + + // link diagrams A and B by adding projections + addProjections(A, B); + + constexpr Real epsThreshold { 1.0e-10 }; + std::pair result { 0.0, 0.0 }; + bool useRangeSearch { true }; + // construct an oracle + BoundMatchOracle oracle(A, B, epsThreshold, useRangeSearch); + + Real& distMin {result.first}; + Real& distMax {result.second}; + + // initialize search interval from initialGuess + distMin = initialGuess.first; + distMax = initialGuess.second; + + assert(distMin <= distMax); + + // make sure that distMin is a lower bound + while(oracle.isMatchLess(distMin)) { + // distMin is in fact an upper bound, so assign it to distMax + distMax = distMin; + // and decrease distMin by 5 % + distMin = 0.95 * distMin; + } + + // make sure that distMax is an upper bound + while(not oracle.isMatchLess(distMax)) { + // distMax is in fact a lower bound, so assign it to distMin + distMin = distMax; + // and increase distMax by 5 % + distMax = 1.05 * distMax; + } + + // bounds are found, perform binary search + Real distProbe = ( distMin + distMax ) / 2.0; + binarySearch(epsilon, result, oracle, infinity_cost, true, distProbe); + return result; +} + +// return the interval (distMin, distMax) such that: +// a) actual bottleneck distance between A and B is contained in the interval +// b) if the interval is not (0,0), then (distMax - distMin) / distMin < epsilon +// use heuristic: initial estimate on sampled diagrams +template +std::pair bottleneckDistApproxIntervalHeur(DiagramPointSet& A, DiagramPointSet& B, const Real epsilon) +{ + // empty diagrams are not considered as error + if (A.empty() and B.empty()) + return std::make_pair(0.0, 0.0); + + Real infinity_cost = getInfinityCost(A, B); + if (infinity_cost == std::numeric_limits::infinity()) + return std::make_pair(infinity_cost, infinity_cost); + + DiagramPointSet sampledA, sampledB; + sampleDiagramForHeur(A, sampledA); + sampleDiagramForHeur(B, sampledB); + + std::pair initGuess = bottleneckDistApproxInterval(sampledA, sampledB, epsilon); + + initGuess.first = std::max(initGuess.first, infinity_cost); + initGuess.second = std::max(initGuess.second, infinity_cost); + + return bottleneckDistApproxIntervalWithInitial(A, B, epsilon, initGuess, infinity_cost); +} + + + +// get approximate distance, +// see bottleneckDistApproxInterval +template +Real bottleneckDistApprox(DiagramPointSet& A, DiagramPointSet& B, const Real epsilon) +{ + auto interval = bottleneckDistApproxInterval(A, B, epsilon); + return interval.second; +} + + +template +Real bottleneckDistExactFromSortedPwDist(DiagramPointSet&A, DiagramPointSet& B, std::vector& pairwiseDist, const int decPrecision) +{ + // trivial case: we have only one candidate + if (pairwiseDist.size() == 1) + return pairwiseDist[0]; + + bool useRangeSearch = true; + Real distEpsilon = std::numeric_limits::max(); + Real diffThreshold = 0.1; + for(int k = 0; k < decPrecision; ++k) { + diffThreshold /= 10.0; + } + for(size_t k = 0; k < pairwiseDist.size() - 2; ++k) { + auto diff = pairwiseDist[k+1]- pairwiseDist[k]; + if ( diff > diffThreshold and diff < distEpsilon ) { + distEpsilon = diff; + } + } + distEpsilon /= 3.0; + + BoundMatchOracle oracle(A, B, distEpsilon, useRangeSearch); + // binary search + size_t iterNum {0}; + size_t idxMin {0}, idxMax {pairwiseDist.size() - 1}; + size_t idxMid; + while(idxMax > idxMin) { + idxMid = static_cast(floor(idxMin + idxMax) / 2.0); + iterNum++; + // not A[imid] < dist <=> A[imid] >= dist <=> A[imid[ >= dist + eps + if (oracle.isMatchLess(pairwiseDist[idxMid] + distEpsilon / 2.0)) { + idxMax = idxMid; + } else { + idxMin = idxMid + 1; + } + } + idxMid = static_cast(floor(idxMin + idxMax) / 2.0); + return pairwiseDist[idxMid]; +} + + +template +Real bottleneckDistExact(DiagramPointSet& A, DiagramPointSet& B) +{ + return bottleneckDistExact(A, B, 14); +} + +template +Real bottleneckDistExact(DiagramPointSet& A, DiagramPointSet& B, const int decPrecision) +{ + using DgmPoint = DiagramPoint; + + constexpr Real epsilon = 0.001; + auto interval = bottleneckDistApproxInterval(A, B, epsilon); + if (interval.first == interval.second) + return interval.first; + const Real delta = 0.50001 * (interval.second - interval.first); + const Real approxDist = 0.5 * ( interval.first + interval.second); + const Real minDist = interval.first; + const Real maxDist = interval.second; + if ( delta == 0 ) { + return interval.first; + } + // copy points from A to a vector + // todo: get rid of this? + std::vector pointsA; + pointsA.reserve(A.size()); + for(const auto& ptA : A) { + pointsA.push_back(ptA); + } + + // in this vector we store the distances between the points + // that are candidates to realize + std::vector pairwiseDist; + { + // vector to store centers of vertical stripes + // two for each point in A and the id of the corresponding point + std::vector> xCentersVec; + xCentersVec.reserve(2 * pointsA.size()); + for(auto ptA : pointsA) { + xCentersVec.push_back(std::make_pair(ptA.getRealX() - approxDist, ptA)); + xCentersVec.push_back(std::make_pair(ptA.getRealX() + approxDist, ptA)); + } + // lambda to compare pairs w.r.t coordinate + auto compLambda = [](std::pair a, std::pair b) + { return a.first < b.first; }; + + std::sort(xCentersVec.begin(), xCentersVec.end(), compLambda); + // todo: sort points in B, reduce search range in lower and upper bounds + for(auto ptB : B) { + // iterator to the first stripe such that ptB lies to the left + // from its right boundary (x_B <= x_j + \delta iff x_j >= x_B - \delta + auto itStart = std::lower_bound(xCentersVec.begin(), + xCentersVec.end(), + std::make_pair(ptB.getRealX() - delta, ptB), + compLambda); + + for(auto iterA = itStart; iterA < xCentersVec.end(); ++iterA) { + if ( ptB.getRealX() < iterA->first - delta) { + // from that moment x_B >= x_j - delta + // is violated: x_B no longer lies to right from the left + // boundary of current stripe + break; + } + // we're here => ptB lies in vertical stripe, + // check if distance fits into the interval we've found + Real pwDist = distLInf(iterA->second, ptB); + if (pwDist >= minDist and pwDist <= maxDist) { + pairwiseDist.push_back(pwDist); + } + } + } + } + + { + // for y + // vector to store centers of vertical stripes + // two for each point in A and the id of the corresponding point + std::vector> yCentersVec; + yCentersVec.reserve(2 * pointsA.size()); + for(auto ptA : pointsA) { + yCentersVec.push_back(std::make_pair(ptA.getRealY() - approxDist, ptA)); + yCentersVec.push_back(std::make_pair(ptA.getRealY() + approxDist, ptA)); + } + // lambda to compare pairs w.r.t coordinate + auto compLambda = [](std::pair a, std::pair b) + { return a.first < b.first; }; + + std::sort(yCentersVec.begin(), yCentersVec.end(), compLambda); + + // todo: sort points in B, reduce search range in lower and upper bounds + for(auto ptB : B) { + auto itStart = std::lower_bound(yCentersVec.begin(), + yCentersVec.end(), + std::make_pair(ptB.getRealY() - delta, ptB), + compLambda); + + + for(auto iterA = itStart; iterA < yCentersVec.end(); ++iterA) { + if ( ptB.getRealY() < iterA->first - delta) { + break; + } + Real pwDist = distLInf(iterA->second, ptB); + if (pwDist >= minDist and pwDist <= maxDist) { + pairwiseDist.push_back(pwDist); + } + } + } + } + + std::sort(pairwiseDist.begin(), pairwiseDist.end()); + + return bottleneckDistExactFromSortedPwDist(A, B, pairwiseDist, decPrecision); +} + +} // end namespace bt +} // end namespace hera +#endif // HERA_BOTTLENECK_HPP diff --git a/src/dionysus/bottleneck/bound_match.h b/src/dionysus/bottleneck/bound_match.h new file mode 100755 index 0000000..770c7df --- /dev/null +++ b/src/dionysus/bottleneck/bound_match.h @@ -0,0 +1,107 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_BOUND_MATCH_H +#define HERA_BOUND_MATCH_H + +#include +#include + +#include "basic_defs_bt.h" +#include "neighb_oracle.h" + + +namespace hera { +namespace bt { + +template +class Matching { +public: + using DgmPoint = DiagramPoint; + using DgmPointSet = DiagramPointSet; + using DgmPointHash = DiagramPointHash; + using Path = std::vector; + + Matching(const DgmPointSet& AA, const DgmPointSet& BB) : A(AA), B(BB) {}; + DgmPointSet getExposedVertices(bool forA = true) const ; + bool isExposed(const DgmPoint& p) const; + void getAllAdjacentVertices(const DgmPointSet& setIn, DgmPointSet& setOut, bool forA = true) const; + void increase(const Path& augmentingPath); + void checkAugPath(const Path& augPath) const; + bool getMatchedVertex(const DgmPoint& p, DgmPoint& result) const; + bool isPerfect() const; + void trimMatching(const Real newThreshold); +#ifndef FOR_R_TDA + template + friend std::ostream& operator<<(std::ostream& output, const Matching& m); +#endif +private: + DgmPointSet A; + DgmPointSet B; + std::unordered_map AToB, BToA; + void matchVertices(const DgmPoint& pA, const DgmPoint& pB); + void sanityCheck() const; +}; + + + +template> +class BoundMatchOracle { +public: + using Real = Real_; + using NeighbOracle = NeighbOracle_; + using DgmPoint = DiagramPoint; + using DgmPointSet = DiagramPointSet; + using Path = std::vector; + + BoundMatchOracle(DgmPointSet psA, DgmPointSet psB, Real dEps, bool useRS = true); + bool isMatchLess(Real r); + bool buildMatchingForThreshold(const Real r); +private: + DgmPointSet A, B; + Matching M; + void printLayerGraph(); + void buildLayerGraph(Real r); + void buildLayerOracles(Real r); + bool buildAugmentingPath(const DgmPoint startVertex, Path& result); + void removeFromLayer(const DgmPoint& p, const int layerIdx); + std::unique_ptr neighbOracle; + bool augPathExist; + std::vector layerGraph; + std::vector> layerOracles; + Real distEpsilon; + bool useRangeSearch; + Real prevQueryValue; +}; + +} // end namespace bt +} // end namespace hera + +#include "bound_match.hpp" + +#endif // HERA_BOUND_MATCH_H diff --git a/src/dionysus/bottleneck/bound_match.hpp b/src/dionysus/bottleneck/bound_match.hpp new file mode 100755 index 0000000..221bd0f --- /dev/null +++ b/src/dionysus/bottleneck/bound_match.hpp @@ -0,0 +1,473 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_BOUND_MATCH_HPP +#define HERA_BOUND_MATCH_HPP + + +#ifdef FOR_R_TDA +#undef DEBUG_BOUND_MATCH +#undef DEBUG_MATCHING +#undef VERBOSE_BOTTLENECK +#endif + +#include +#include "def_debug_bt.h" +#include "bound_match.h" + +#ifdef VERBOSE_BOTTLENECK +#include +#endif + +#ifndef FOR_R_TDA +#include +#endif + +namespace hera{ +namespace bt { + +#ifndef FOR_R_TDA +template +std::ostream& operator<<(std::ostream& output, const Matching& m) +{ + output << "Matching: " << m.AToB.size() << " pairs ("; + if (!m.isPerfect()) { + output << "not"; + } + output << " perfect)" << std::endl; + for(auto atob : m.AToB) { + output << atob.first << " <-> " << atob.second << " distance: " << distLInf(atob.first, atob.second) << std::endl; + } + return output; +} +#endif + +template +void Matching::sanityCheck() const +{ +#ifdef DEBUG_MATCHING + assert( AToB.size() == BToA.size() ); + for(auto aToBPair : AToB) { + auto bToAPair = BToA.find(aToBPair.second); + assert(bToAPair != BToA.end()); + assert( aToBPair.first == bToAPair->second); + } +#endif +} + +template +bool Matching::isPerfect() const +{ + return AToB.size() == A.size(); +} + +template +void Matching::matchVertices(const DgmPoint& pA, const DgmPoint& pB) +{ + assert(A.hasElement(pA)); + assert(B.hasElement(pB)); + AToB.erase(pA); + AToB.insert( {{ pA, pB }} ); + BToA.erase(pB); + BToA.insert( {{ pB, pA }} ); +} + +template +bool Matching::getMatchedVertex(const DgmPoint& p, DgmPoint& result) const +{ + sanityCheck(); + auto inA = AToB.find(p); + if (inA != AToB.end()) { + result = inA->second; + return true; + } else { + auto inB = BToA.find(p); + if (inB != BToA.end()) { + result = inB->second; + return true; + } + } + return false; +} + + +template +void Matching::checkAugPath(const Path& augPath) const +{ + assert(augPath.size() % 2 == 0); + for(size_t idx = 0; idx < augPath.size(); ++idx) { + bool mustBeExposed { idx == 0 or idx == augPath.size() - 1 }; + if (isExposed(augPath[idx]) != mustBeExposed) { +#ifndef FOR_R_TDA + std::cerr << "mustBeExposed = " << mustBeExposed << ", idx = " << idx << ", point " << augPath[idx] << std::endl; +#endif + } + assert( isExposed(augPath[idx]) == mustBeExposed ); + DgmPoint matchedVertex {0.0, 0.0, DgmPoint::DIAG, 1}; + if ( idx % 2 == 0 ) { + assert( A.hasElement(augPath[idx])); + if (not mustBeExposed) { + getMatchedVertex(augPath[idx], matchedVertex); + assert(matchedVertex == augPath[idx - 1]); + } + } else { + assert( B.hasElement(augPath[idx])); + if (not mustBeExposed) { + getMatchedVertex(augPath[idx], matchedVertex); + assert(matchedVertex == augPath[idx + 1]); + } + } + } +} + +// use augmenting path to increase +// the size of the matching +template +void Matching::increase(const Path& augPath) +{ + sanityCheck(); + // check that augPath is an augmenting path + checkAugPath(augPath); + for(size_t idx = 0; idx < augPath.size() - 1; idx += 2) { + matchVertices( augPath[idx], augPath[idx + 1]); + } + sanityCheck(); +} + +template +DiagramPointSet Matching::getExposedVertices(bool forA) const +{ + sanityCheck(); + DgmPointSet result; + const DgmPointSet* setToSearch { forA ? &A : &B }; + const std::unordered_map* mapToSearch { forA ? &AToB : &BToA }; + for(auto it = setToSearch->cbegin(); it != setToSearch->cend(); ++it) { + if (mapToSearch->find((*it)) == mapToSearch->cend()) { + result.insert((*it)); + } + } + return result; +} + +template +void Matching::getAllAdjacentVertices(const DgmPointSet& setIn, + DgmPointSet& setOut, + bool forA) const +{ + sanityCheck(); + //bool isDebug {false}; + setOut.clear(); + const std::unordered_map* m; + m = ( forA ) ? &BToA : &AToB; + for(auto pit = setIn.cbegin(); pit != setIn.cend(); ++pit) { + auto findRes = m->find(*pit); + if (findRes != m->cend()) { + setOut.insert((*findRes).second); + } + } + sanityCheck(); +} + +template +bool Matching::isExposed(const DgmPoint& p) const +{ + return ( AToB.find(p) == AToB.end() ) && ( BToA.find(p) == BToA.end() ); +} + +// remove all edges whose length is > newThreshold +template +void Matching::trimMatching(const R newThreshold) +{ + //bool isDebug { false }; + sanityCheck(); + for(auto aToBIter = AToB.begin(); aToBIter != AToB.end(); ) { + if ( distLInf(aToBIter->first, aToBIter->second) > newThreshold ) { + // remove edge from AToB and BToA + BToA.erase(aToBIter->second); + aToBIter = AToB.erase(aToBIter); + } else { + aToBIter++; + } + } + sanityCheck(); +} + +// ------- BoundMatchOracle -------------- + +template +BoundMatchOracle::BoundMatchOracle(DgmPointSet psA, DgmPointSet psB, + Real dEps, bool useRS) : + A(psA), B(psB), M(A, B), distEpsilon(dEps), useRangeSearch(useRS), prevQueryValue(0.0) +{ + neighbOracle = std::unique_ptr(new NeighbOracle(psB, 0, distEpsilon)); +} + +template +bool BoundMatchOracle::isMatchLess(Real r) +{ +#ifdef VERBOSE_BOTTLENECK + std::chrono::high_resolution_clock hrClock; + std::chrono::time_point startMoment; + startMoment = hrClock.now(); +#endif + bool result = buildMatchingForThreshold(r); +#ifdef VERBOSE_BOTTLENECK + auto endMoment = hrClock.now(); + std::chrono::duration iterTime = endMoment - startMoment; + std::cout << "isMatchLess for r = " << r << " finished in " << std::chrono::duration(iterTime).count() << " ms." << std::endl; +#endif + return result; + +} + + +template +void BoundMatchOracle::removeFromLayer(const DgmPoint& p, const int layerIdx) { + //bool isDebug {false}; + layerGraph[layerIdx].erase(p); + if (layerOracles[layerIdx]) { + layerOracles[layerIdx]->deletePoint(p); + } +} + +// return true, if there exists an augmenting path from startVertex +// in this case the path is returned in result. +// startVertex must be an exposed vertex from L_1 (layer[0]) +template +bool BoundMatchOracle::buildAugmentingPath(const DgmPoint startVertex, Path& result) +{ + //bool isDebug {false}; + DgmPoint prevVertexA = startVertex; + result.clear(); + result.push_back(startVertex); + size_t evenLayerIdx {1}; + while ( evenLayerIdx < layerGraph.size() ) { + //for(size_t evenLayerIdx = 1; evenLayerIdx < layerGraph.size(); evenLayerIdx += 2) { + DgmPoint nextVertexB{0.0, 0.0, DgmPoint::DIAG, 1}; // next vertex from even layer + bool neighbFound = layerOracles[evenLayerIdx]->getNeighbour(prevVertexA, nextVertexB); + if (neighbFound) { + result.push_back(nextVertexB); + if ( layerGraph.size() == evenLayerIdx + 1) { + break; + } else { + // nextVertexB must be matched with some vertex from the next odd + // layer + DgmPoint nextVertexA {0.0, 0.0, DgmPoint::DIAG, 1}; + if (!M.getMatchedVertex(nextVertexB, nextVertexA)) { +#ifndef FOR_R_TDA + std::cerr << "Vertices in even layers must be matched! Unmatched: "; + std::cerr << nextVertexB << std::endl; + std::cerr << evenLayerIdx << "; " << layerGraph.size() << std::endl; +#endif + throw std::runtime_error("Unmatched vertex in even layer"); + } else { + assert( ! (nextVertexA.getRealX() == 0 and nextVertexA.getRealY() == 0) ); + result.push_back(nextVertexA); + prevVertexA = nextVertexA; + evenLayerIdx += 2; + continue; + } + } + } else { + // prevVertexA has no neighbours in the next layer, + // backtrack + if (evenLayerIdx == 1) { + // startVertex is not connected to any vertices + // in the next layer, augm. path doesn't exist + removeFromLayer(startVertex, 0); + return false; + } else { + assert(evenLayerIdx >= 3); + assert(result.size() % 2 == 1); + result.pop_back(); + DgmPoint prevVertexB = result.back(); + result.pop_back(); + removeFromLayer(prevVertexA, evenLayerIdx-1); + removeFromLayer(prevVertexB, evenLayerIdx-2); + // we should proceed from the previous odd layer + assert(result.size() >= 1); + prevVertexA = result.back(); + evenLayerIdx -= 2; + continue; + } + } + } // finished iterating over all layers + // remove all vertices in the augmenting paths + // the corresponding layers + for(size_t layerIdx = 0; layerIdx < result.size(); ++layerIdx) { + removeFromLayer(result[layerIdx], layerIdx); + } + return true; +} + + + + +template +bool BoundMatchOracle::buildMatchingForThreshold(const Real r) +{ + //bool isDebug {false}; + if (prevQueryValue > r) { + M.trimMatching(r); + } + prevQueryValue = r; + while(true) { + buildLayerGraph(r); + if (augPathExist) { + std::vector augmentingPaths; + DgmPointSet copyLG0; + for(DgmPoint p : layerGraph[0]) { + copyLG0.insert(p); + } + for(DgmPoint exposedVertex : copyLG0) { + Path augPath; + if (buildAugmentingPath(exposedVertex, augPath)) { + augmentingPaths.push_back(augPath); + } + } + if (augmentingPaths.empty()) { +#ifndef FOR_R_TDA + std::cerr << "augmenting paths must exist, but were not found!" << std::endl; +#endif + throw std::runtime_error("bad epsilon?"); + } + // swap all augmenting paths with matching to increase it + for(auto& augPath : augmentingPaths ) { + M.increase(augPath); + } + } else { + return M.isPerfect(); + } + } +} + +template +void BoundMatchOracle::printLayerGraph(void) +{ +#ifdef DEBUG_BOUND_MATCH + for(auto& layer : layerGraph) { + std::cout << "{ "; + for(auto& p : layer) { + std::cout << p << "; "; + } + std::cout << "\b\b }" << std::endl; + } +#endif +} + +template +void BoundMatchOracle::buildLayerGraph(Real r) +{ +#ifdef VERBOSE_BOTTLENECK + std::cout << "Entered buildLayerGraph, r = " << r << std::endl; +#endif + layerGraph.clear(); + DgmPointSet L1 = M.getExposedVertices(); + layerGraph.push_back(L1); + neighbOracle->rebuild(B, r); + size_t k = 0; + DgmPointSet layerNextEven; + DgmPointSet layerNextOdd; + bool exposedVerticesFound {false}; + while(true) { + layerNextEven.clear(); + for( auto p : layerGraph[k]) { + bool neighbFound; + DgmPoint neighbour {0.0, 0.0, DgmPoint::DIAG, 1}; + if (useRangeSearch) { + std::vector neighbVec; + neighbOracle->getAllNeighbours(p, neighbVec); + neighbFound = !neighbVec.empty(); + for(auto& neighbPt : neighbVec) { + layerNextEven.insert(neighbPt); + if (!exposedVerticesFound and M.isExposed(neighbPt)) + exposedVerticesFound = true; + } + } else { + while(true) { + neighbFound = neighbOracle->getNeighbour(p, neighbour); + if (neighbFound) { + layerNextEven.insert(neighbour); + neighbOracle->deletePoint(neighbour); + if ((!exposedVerticesFound) && M.isExposed(neighbour)) { + exposedVerticesFound = true; + } + } else { + break; + } + } + } // without range search + } // all vertices from previous odd layer processed + if (layerNextEven.empty()) { + augPathExist = false; + break; + } + if (exposedVerticesFound) { + for(auto it = layerNextEven.cbegin(); it != layerNextEven.cend(); ) { + if ( ! M.isExposed(*it) ) { + layerNextEven.erase(it++); + } else { + ++it; + } + + } + layerGraph.push_back(layerNextEven); + augPathExist = true; + break; + } + layerGraph.push_back(layerNextEven); + M.getAllAdjacentVertices(layerNextEven, layerNextOdd); + layerGraph.push_back(layerNextOdd); + k += 2; + } + buildLayerOracles(r); + printLayerGraph(); + } + +// create geometric oracles for each even layer +// odd layers have NULL in layerOracles +template +void BoundMatchOracle::buildLayerOracles(Real r) +{ + //bool isDebug {false}; + // free previously constructed oracles + layerOracles.clear(); + for(size_t layerIdx = 0; layerIdx < layerGraph.size(); ++layerIdx) { + if (layerIdx % 2 == 1) { + // even layer, build actual oracle + layerOracles.emplace_back(new NeighbOracle(layerGraph[layerIdx], r, distEpsilon)); + } else { + // odd layer + layerOracles.emplace_back(nullptr); + } + } +} + +} // end namespace bt +} // end namespace hera +#endif // HERA_BOUND_MATCH_HPP diff --git a/src/dionysus/bottleneck/def_debug_bt.h b/src/dionysus/bottleneck/def_debug_bt.h new file mode 100755 index 0000000..21557e7 --- /dev/null +++ b/src/dionysus/bottleneck/def_debug_bt.h @@ -0,0 +1,42 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef DEF_DEBUG_BT_H +#define DEF_DEBUG_BT_H + +//#define DEBUG_BOUND_MATCH +//#define DEBUG_NEIGHBOUR_ORACLE +//#define DEBUG_MATCHING +//#define DEBUG_AUCTION +// This symbol should be defined only in the version +// for R package TDA, to comply with some CRAN rules +// like no usage of cout, cerr, cin, exit, etc. +//#define FOR_R_TDA +//#define VERBOSE_BOTTLENECK + +#endif diff --git a/src/dionysus/bottleneck/diagram_reader.h b/src/dionysus/bottleneck/diagram_reader.h new file mode 100755 index 0000000..08d9e2b --- /dev/null +++ b/src/dionysus/bottleneck/diagram_reader.h @@ -0,0 +1,196 @@ +/* +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_DIAGRAM_READER_H +#define HERA_DIAGRAM_READER_H + +#ifndef FOR_R_TDA +#include +#endif + +#include +#include +#include +#include + +namespace hera { + +// cannot choose stod, stof or stold based on RealType, +// lazy solution: partial specialization + template + inline RealType parse_real_from_str(const std::string& s); + + template <> + inline double parse_real_from_str(const std::string& s) + { + return std::stod(s); + } + + + template <> + inline long double parse_real_from_str(const std::string& s) + { + return std::stold(s); + } + + + template <> + inline float parse_real_from_str(const std::string& s) + { + return std::stof(s); + } + + + template + inline RealType parse_real_from_str(const std::string& s) + { + static_assert(sizeof(RealType) != sizeof(RealType), "Must be specialized for each type you want to use, see above"); + } + +// fill in result with points from file fname +// return false if file can't be opened +// or error occurred while reading +// decPrecision is the maximal decimal precision in the input, +// it is zero if all coordinates in the input are integers + + template>> + inline bool readDiagramPointSet(const char* fname, ContType_& result, int& decPrecision) + { + using RealType = RealType_; + + size_t lineNumber { 0 }; + result.clear(); + std::ifstream f(fname); + if (!f.good()) { +#ifndef FOR_R_TDA + std::cerr << "Cannot open file " << fname << std::endl; +#endif + return false; + } + std::locale loc; + std::string line; + while(std::getline(f, line)) { + lineNumber++; + // process comments: remove everything after hash + auto hashPos = line.find_first_of("#", 0); + if( std::string::npos != hashPos) { + line = std::string(line.begin(), line.begin() + hashPos); + } + if (line.empty()) { + continue; + } + // trim whitespaces + auto whiteSpaceFront = std::find_if_not(line.begin(),line.end(),isspace); + auto whiteSpaceBack = std::find_if_not(line.rbegin(),line.rend(),isspace).base(); + if (whiteSpaceBack <= whiteSpaceFront) { + // line consists of spaces only - move to the next line + continue; + } + line = std::string(whiteSpaceFront,whiteSpaceBack); + + // transform line to lower case + // to parse Infinity + for(auto& c : line) { + c = std::tolower(c, loc); + } + + bool fracPart = false; + int currDecPrecision = 0; + for(auto c : line) { + if (c == '.') { + fracPart = true; + } else if (fracPart) { + if (isdigit(c)) { + currDecPrecision++; + } else { + fracPart = false; + if (currDecPrecision > decPrecision) + decPrecision = currDecPrecision; + currDecPrecision = 0; + } + } + } + + RealType x, y; + std::string str_x, str_y; + std::istringstream iss(line); + try { + iss >> str_x >> str_y; + + x = parse_real_from_str(str_x); + y = parse_real_from_str(str_y); + + if (x != y) { + result.push_back(std::make_pair(x, y)); + } else { +#ifndef FOR_R_TDA + std::cerr << "Warning: point with 0 persistence ignored in " << fname << ":" << lineNumber << "\n"; +#endif + } + } + catch (const std::invalid_argument& e) { +#ifndef FOR_R_TDA + std::cerr << "Error in file " << fname << ", line number " << lineNumber << ": cannot parse \"" << line << "\"" << std::endl; +#endif + return false; + } + catch (const std::out_of_range&) { +#ifndef FOR_R_TDA + std::cerr << "Error while reading file " << fname << ", line number " << lineNumber << ": value too large in \"" << line << "\"" << std::endl; +#endif + return false; + } + } + f.close(); + return true; + } + + // wrappers + template>> + inline bool readDiagramPointSet(const std::string& fname, ContType_& result, int& decPrecision) + { + return readDiagramPointSet(fname.c_str(), result, decPrecision); + } + + // these two functions are now just wrappers for the previous ones, + // in case someone needs them; decPrecision is ignored + template>> + inline bool readDiagramPointSet(const char* fname, ContType_& result) + { + int decPrecision; + return readDiagramPointSet(fname, result, decPrecision); + } + + template>> + inline bool readDiagramPointSet(const std::string& fname, ContType_& result) + { + int decPrecision; + return readDiagramPointSet(fname.c_str(), result, decPrecision); + } + +} // end namespace hera +#endif // HERA_DIAGRAM_READER_H diff --git a/src/dionysus/bottleneck/diagram_traits.h b/src/dionysus/bottleneck/diagram_traits.h new file mode 100755 index 0000000..c8d4862 --- /dev/null +++ b/src/dionysus/bottleneck/diagram_traits.h @@ -0,0 +1,45 @@ +#ifndef HERA_DIAGRAM_TRAITS_H +#define HERA_DIAGRAM_TRAITS_H + +namespace hera { + +template().begin())>::type > +struct DiagramTraits +{ + using Container = PairContainer_; + using PointType = PointType_; + using RealType = typename std::remove_reference< decltype(std::declval()[0]) >::type; + + static RealType get_x(const PointType& p) { return p[0]; } + static RealType get_y(const PointType& p) { return p[1]; } +}; + + +template +struct DiagramTraits> +{ + using PointType = std::pair; + using RealType = double; + using Container = std::vector; + + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; + + +template +struct DiagramTraits> +{ + using PointType = std::pair; + using RealType = float; + using Container = std::vector; + + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; + + +} // end namespace hera + + +#endif // HERA_DIAGRAM_TRAITS_H diff --git a/src/dionysus/bottleneck/dnn/geometry/euclidean-fixed.h b/src/dionysus/bottleneck/dnn/geometry/euclidean-fixed.h new file mode 100755 index 0000000..f45b980 --- /dev/null +++ b/src/dionysus/bottleneck/dnn/geometry/euclidean-fixed.h @@ -0,0 +1,162 @@ +#ifndef HERA_BT_DNN_GEOMETRY_EUCLIDEAN_FIXED_H +#define HERA_BT_DNN_GEOMETRY_EUCLIDEAN_FIXED_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../parallel/tbb.h" // for dnn::vector<...> + +namespace hera { +namespace bt { +namespace dnn +{ + // TODO: wrap in another namespace (e.g., euclidean) + + template + struct Point: + boost::addable< Point, + boost::subtractable< Point, + boost::dividable2< Point, Real, + boost::multipliable2< Point, Real > > > >, + public boost::array + { + public: + typedef Real Coordinate; + typedef Real DistanceType; + + + public: + Point(size_t id = 0): id_(id) {} + template + Point(const Point& p, size_t id = 0): + id_(id) { *this = p; } + + static size_t dimension() { return D; } + + // Assign a point of different dimension + template + Point& operator=(const Point& p) { for (size_t i = 0; i < (D < DD ? D : DD); ++i) (*this)[i] = p[i]; if (DD < D) for (size_t i = DD; i < D; ++i) (*this)[i] = 0; return *this; } + + Point& operator+=(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] += p[i]; return *this; } + Point& operator-=(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] -= p[i]; return *this; } + Point& operator/=(Real r) { for (size_t i = 0; i < D; ++i) (*this)[i] /= r; return *this; } + Point& operator*=(Real r) { for (size_t i = 0; i < D; ++i) (*this)[i] *= r; return *this; } + + Real norm2() const { Real n = 0; for (size_t i = 0; i < D; ++i) n += (*this)[i] * (*this)[i]; return n; } + Real max_norm() const + { + Real m = std::fabs((*this)[0]); + for (size_t i = 1; i < D; ++i) + if (std::fabs((*this)[i]) > m) + m = std::fabs((*this)[i]); + return m; } + + // quick and dirty for now; make generic later + //DistanceType distance(const Point& other) const { return sqrt(sq_distance(other)); } + //DistanceType sq_distance(const Point& other) const { return (other - *this).norm2(); } + + DistanceType distance(const Point& other) const { return (other - *this).max_norm(); } + DistanceType sq_distance(const Point& other) const { DistanceType d = distance(other); return d*d; } + + size_t id() const { return id_; } + size_t& id() { return id_; } + + private: + friend class boost::serialization::access; + + template + void serialize(Archive& ar, const unsigned int version) { ar & boost::serialization::base_object< boost::array >(*this) & id_; } + + private: + size_t id_; + }; + + template + std::ostream& + operator<<(std::ostream& out, const Point& p) + { out << p[0]; for (size_t i = 1; i < D; ++i) out << " " << p[i]; return out; } + + + template + struct PointTraits; // intentionally undefined; should be specialized for each type + + template + struct PointTraits< Point > // specialization for dnn::Point + { + typedef Point PointType; + typedef const PointType* PointHandle; + typedef std::vector PointContainer; + + typedef typename PointType::Coordinate Coordinate; + typedef typename PointType::DistanceType DistanceType; + + static DistanceType + distance(const PointType& p1, const PointType& p2) { return p1.distance(p2); } + static DistanceType + distance(PointHandle p1, PointHandle p2) { return distance(*p1,*p2); } + static DistanceType + sq_distance(const PointType& p1, + const PointType& p2) { return p1.sq_distance(p2); } + static DistanceType + sq_distance(PointHandle p1, PointHandle p2) { return sq_distance(*p1,*p2); } + static size_t dimension() { return D; } + static Real coordinate(const PointType& p, size_t i) { return p[i]; } + static Real& coordinate(PointType& p, size_t i) { return p[i]; } + static Real coordinate(PointHandle p, size_t i) { return coordinate(*p,i); } + + static size_t id(const PointType& p) { return p.id(); } + static size_t& id(PointType& p) { return p.id(); } + static size_t id(PointHandle p) { return id(*p); } + + static PointHandle + handle(const PointType& p) { return &p; } + static const PointType& + point(PointHandle ph) { return *ph; } + + void swap(PointType& p1, PointType& p2) const { return std::swap(p1, p2); } + + static PointContainer + container(size_t n = 0, const PointType& p = PointType()) { return PointContainer(n, p); } + static typename PointContainer::iterator + iterator(PointContainer& c, PointHandle ph) { return c.begin() + (ph - &c[0]); } + static typename PointContainer::const_iterator + iterator(const PointContainer& c, PointHandle ph) { return c.begin() + (ph - &c[0]); } + + private: + friend class boost::serialization::access; + + template + void serialize(Archive& ar, const unsigned int version) {} + }; + + template + void read_points(const std::string& filename, PointContainer& points) + { + typedef typename boost::range_value::type Point; + typedef typename PointTraits::Coordinate Coordinate; + + std::ifstream in(filename.c_str()); + std::string line; + while(std::getline(in, line)) + { + if (line[0] == '#') continue; // comment line in the file + std::stringstream linestream(line); + Coordinate x; + points.push_back(Point()); + size_t i = 0; + while (linestream >> x) + points.back()[i++] = x; + } + } +} // dnn +} // bt +} // hera +#endif diff --git a/src/dionysus/bottleneck/dnn/local/kd-tree.h b/src/dionysus/bottleneck/dnn/local/kd-tree.h new file mode 100755 index 0000000..c1aed2b --- /dev/null +++ b/src/dionysus/bottleneck/dnn/local/kd-tree.h @@ -0,0 +1,106 @@ +#ifndef HERA_BT_DNN_LOCAL_KD_TREE_H +#define HERA_BT_DNN_LOCAL_KD_TREE_H + +#include "../utils.h" +#include "search-functors.h" + +#include +#include + +#include +#include +#include + +#include +#include + +namespace hera { +namespace bt { +namespace dnn +{ + // Weighted KDTree + // Traits_ provides Coordinate, DistanceType, PointType, dimension(), distance(p1,p2), coordinate(p,i) + template< class Traits_ > + class KDTree + { + public: + typedef Traits_ Traits; + typedef hera::bt::dnn::HandleDistance HandleDistance; + + typedef typename Traits::PointType Point; + typedef typename Traits::PointHandle PointHandle; + typedef typename Traits::Coordinate Coordinate; + typedef typename Traits::DistanceType DistanceType; + typedef std::vector HandleContainer; + typedef std::vector HDContainer; // TODO: use tbb::scalable_allocator + typedef HDContainer Result; + typedef std::vector DistanceContainer; + typedef std::unordered_map HandleMap; + //private: + typedef typename HandleContainer::iterator HCIterator; + typedef std::tuple KDTreeNode; + typedef std::tuple KDTreeNodeNoCut; + + //BOOST_STATIC_ASSERT_MSG(has_coordinates::value, "KDTree requires coordinates"); + + public: + KDTree(const Traits& traits): + traits_(traits) {} + + KDTree(const Traits& traits, HandleContainer&& handles); + + template + KDTree(const Traits& traits, const Range& range); + + template + void init(const Range& range); + + HandleDistance find(PointHandle q) const; + Result findR(PointHandle q, DistanceType r) const; // all neighbors within r + Result findFirstR(PointHandle q, DistanceType r) const; // first neighbor within r + Result findK(PointHandle q, size_t k) const; // k nearest neighbors + + HandleDistance find(const Point& q) const { return find(traits().handle(q)); } + Result findR(const Point& q, DistanceType r) const { return findR(traits().handle(q), r); } + Result findFirstR(const Point& q, DistanceType r) const { return findFirstR(traits().handle(q), r); } + Result findK(const Point& q, size_t k) const { return findK(traits().handle(q), k); } + + + + template + void search(PointHandle q, ResultsFunctor& rf) const; + + const Traits& traits() const { return traits_; } + + void get_path_to_root(const size_t idx, std::stack& s); + // to support deletion + void init_n_elems(); + void delete_point(const size_t idx); + void delete_point(PointHandle p); + void update_n_elems(const ssize_t idx, const int delta); + void increase_n_elems(const ssize_t idx); + void decrease_n_elems(const ssize_t idx); + size_t get_num_points() const { return num_points_; } + //private: + void init(); + + + struct CoordinateComparison; + struct OrderTree; + + //private: + Traits traits_; + HandleContainer tree_; + std::vector delete_flags_; + std::vector subtree_n_elems; + HandleMap indices_; + std::vector parents_; + + size_t num_points_; + }; +} // dnn +} // bt +} // hera +#include "kd-tree.hpp" + +#endif diff --git a/src/dionysus/bottleneck/dnn/local/kd-tree.hpp b/src/dionysus/bottleneck/dnn/local/kd-tree.hpp new file mode 100755 index 0000000..249fa55 --- /dev/null +++ b/src/dionysus/bottleneck/dnn/local/kd-tree.hpp @@ -0,0 +1,296 @@ +#include +#include +#include + +#include + +#include "../parallel/tbb.h" + +template +hera::bt::dnn::KDTree::KDTree(const Traits& traits, HandleContainer&& handles): + traits_(traits), + tree_(std::move(handles)), + delete_flags_(handles.size(), static_cast(0) ), + subtree_n_elems(handles.size(), static_cast(0)), + num_points_(handles.size()) +{ + init(); +} + +template +template +hera::bt::dnn::KDTree::KDTree(const Traits& traits, const Range& range): + traits_(traits) +{ + init(range); +} + +template +template +void hera::bt::dnn::KDTree::init(const Range& range) +{ + size_t sz = std::distance(std::begin(range), std::end(range)); + subtree_n_elems = std::vector(sz, 0); + delete_flags_ = std::vector(sz, 0); + num_points_ = sz; + tree_.reserve(sz); + for (PointHandle h : range) + tree_.push_back(h); + parents_.resize(sz, -1); + init(); +} + +template +void hera::bt::dnn::KDTree::init() +{ + if (tree_.empty()) + return; + +#if defined(TBB) + task_group g; + g.run(OrderTree(this, tree_.begin(), tree_.end(), -1, 0, traits())); + g.wait(); +#else + OrderTree(this, tree_.begin(), tree_.end(), -1, 0, traits()).serial(); +#endif + + for (size_t i = 0; i < tree_.size(); ++i) + indices_[tree_[i]] = i; + init_n_elems(); +} + +template +struct +hera::bt::dnn::KDTree::OrderTree +{ + OrderTree(KDTree* tree_, HCIterator b_, HCIterator e_, ssize_t p_, size_t i_, const Traits& traits_): + tree(tree_), b(b_), e(e_), p(p_), i(i_), traits(traits_) {} + + void operator()() const + { + if (e - b < 1000) + { + serial(); + return; + } + + HCIterator m = b + (e - b)/2; + ssize_t im = m - tree->tree_.begin(); + tree->parents_[im] = p; + + CoordinateComparison cmp(i, traits); + std::nth_element(b,m,e, cmp); + size_t next_i = (i + 1) % traits.dimension(); + + task_group g; + if (b < m - 1) g.run(OrderTree(tree, b, m, im, next_i, traits)); + if (e > m + 2) g.run(OrderTree(tree, m+1, e, im, next_i, traits)); + g.wait(); + } + + void serial() const + { + std::queue q; + q.push(KDTreeNode(b,e,p,i)); + while (!q.empty()) + { + HCIterator b, e; ssize_t p; size_t i; + std::tie(b,e,p,i) = q.front(); + q.pop(); + HCIterator m = b + (e - b)/2; + ssize_t im = m - tree->tree_.begin(); + tree->parents_[im] = p; + + CoordinateComparison cmp(i, traits); + std::nth_element(b,m,e, cmp); + size_t next_i = (i + 1) % traits.dimension(); + + // Replace with a size condition instead? + if (b < m - 1) + q.push(KDTreeNode(b, m, im, next_i)); + else if (b < m) + tree->parents_[im - 1] = im; + if (e > m + 2) + q.push(KDTreeNode(m+1, e, im, next_i)); + else if (e > m + 1) + tree->parents_[im + 1] = im; + } + } + + KDTree* tree; + HCIterator b, e; + ssize_t p; + size_t i; + const Traits& traits; +}; + +template +void hera::bt::dnn::KDTree::update_n_elems(ssize_t idx, const int delta) +// add delta to the number of points in node idx and update subtree_n_elems +// for all parents of the node idx +{ + //std::cout << "subtree_n_elems.size = " << subtree_n_elems.size() << std::endl; + // update the node itself + while (idx != -1) + { + //std::cout << idx << std::endl; + subtree_n_elems[idx] += delta; + idx = parents_[idx]; + } +} + +template +void hera::bt::dnn::KDTree::increase_n_elems(const ssize_t idx) +{ + update_n_elems(idx, static_cast(1)); +} + +template +void hera::bt::dnn::KDTree::decrease_n_elems(const ssize_t idx) +{ + update_n_elems(idx, static_cast(-1)); +} + +template +void hera::bt::dnn::KDTree::init_n_elems() +{ + for(size_t idx = 0; idx < tree_.size(); ++idx) { + increase_n_elems(idx); + } +} + + +template +template +void hera::bt::dnn::KDTree::search(PointHandle q, ResultsFunctor& rf) const +{ + typedef typename HandleContainer::const_iterator HCIterator; + typedef std::tuple KDTreeNode; + + if (tree_.empty()) + return; + + DistanceType D = std::numeric_limits::infinity(); + + // TODO: use tbb::scalable_allocator for the queue + std::queue nodes; + + nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0)); + + //std::cout << "started kdtree::search" << std::endl; + + while (!nodes.empty()) + { + HCIterator b, e; size_t i; + std::tie(b,e,i) = nodes.front(); + nodes.pop(); + + CoordinateComparison cmp(i, traits()); + i = (i + 1) % traits().dimension(); + + HCIterator m = b + (e - b)/2; + size_t m_idx = m - tree_.begin(); + // ignore deleted points + if ( delete_flags_[m_idx] == 0 ) { + DistanceType dist = traits().distance(q, *m); + // + weights_[m - tree_.begin()]; + //std::cout << "Supplied to functor: m : "; + //std::cout << "(" << (*(*m))[0] << ", " << (*(*m))[1] << ")"; + //std::cout << " and q : "; + //std::cout << "(" << (*q)[0] << ", " << (*q)[1] << ")" << std::endl; + //std::cout << "dist^q + weight = " << dist << std::endl; + //std::cout << "weight = " << weights_[m - tree_.begin()] << std::endl; + //std::cout << "dist = " << traits().distance(q, *m) << std::endl; + //std::cout << "dist^q = " << pow(traits().distance(q, *m), wassersteinPower) << std::endl; + + D = rf(*m, dist); + } + // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball + Coordinate diff = cmp.diff(q, *m); // diff returns signed distance + DistanceType diffToWasserPower = (diff > 0 ? 1.0 : -1.0) * fabs(diff); + + size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin(); + if ( subtree_n_elems[lm] > 0 ) { + if (e > m + 1 && diffToWasserPower >= -D) { + nodes.push(KDTreeNode(m+1, e, i)); + } + } + + size_t rm = b + (m - b) / 2 - tree_.begin(); + if ( subtree_n_elems[rm] > 0 ) { + if (b < m && diffToWasserPower <= D) { + nodes.push(KDTreeNode(b, m, i)); + } + } + } + //std::cout << "exited kdtree::search" << std::endl; +} + +template +typename hera::bt::dnn::KDTree::HandleDistance hera::bt::dnn::KDTree::find(PointHandle q) const +{ + hera::bt::dnn::NNRecord nn; + search(q, nn); + return nn.result; +} + +template +typename hera::bt::dnn::KDTree::Result hera::bt::dnn::KDTree::findR(PointHandle q, DistanceType r) const +{ + hera::bt::dnn::rNNRecord rnn(r); + search(q, rnn); + //std::sort(rnn.result.begin(), rnn.result.end()); + return rnn.result; +} + +template +typename hera::bt::dnn::KDTree::Result hera::bt::dnn::KDTree::findFirstR(PointHandle q, DistanceType r) const +{ + hera::bt::dnn::firstrNNRecord rnn(r); + search(q, rnn); + return rnn.result; +} + +template +typename hera::bt::dnn::KDTree::Result hera::bt::dnn::KDTree::findK(PointHandle q, size_t k) const +{ + hera::bt::dnn::kNNRecord knn(k); + search(q, knn); + // do we need this??? + std::sort(knn.result.begin(), knn.result.end()); + return knn.result; +} + +template +struct hera::bt::dnn::KDTree::CoordinateComparison +{ + CoordinateComparison(size_t i, const Traits& traits): + i_(i), traits_(traits) {} + + bool operator()(PointHandle p1, PointHandle p2) const { return coordinate(p1) < coordinate(p2); } + Coordinate diff(PointHandle p1, PointHandle p2) const { return coordinate(p1) - coordinate(p2); } + + Coordinate coordinate(PointHandle p) const { return traits_.coordinate(p, i_); } + size_t axis() const { return i_; } + + private: + size_t i_; + const Traits& traits_; +}; + +template +void hera::bt::dnn::KDTree::delete_point(const size_t idx) +{ + // prevent double deletion + assert(delete_flags_[idx] == 0); + delete_flags_[idx] = 1; + decrease_n_elems(idx); + --num_points_; +} + +template +void hera::bt::dnn::KDTree::delete_point(PointHandle p) +{ + delete_point(indices_[p]); +} + diff --git a/src/dionysus/bottleneck/dnn/local/search-functors.h b/src/dionysus/bottleneck/dnn/local/search-functors.h new file mode 100755 index 0000000..63ad11d --- /dev/null +++ b/src/dionysus/bottleneck/dnn/local/search-functors.h @@ -0,0 +1,119 @@ +#ifndef HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H +#define HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H + +#include + +namespace hera +{ +namespace bt +{ +namespace dnn +{ + +template +struct HandleDistance +{ + typedef typename NN::PointHandle PointHandle; + typedef typename NN::DistanceType DistanceType; + typedef typename NN::HDContainer HDContainer; + + HandleDistance() {} + HandleDistance(PointHandle pp, DistanceType dd): + p(pp), d(dd) {} + bool operator<(const HandleDistance& other) const { return d < other.d; } + + PointHandle p; + DistanceType d; +}; + +template +struct NNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + + NNRecord() { result.d = std::numeric_limits::infinity(); } + DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; } + HandleDistance result; +}; + +template +struct rNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + rNNRecord(DistanceType r_): r(r_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (d <= r) + result.push_back(HandleDistance(p,d)); + return r; + } + + DistanceType r; + HDContainer result; +}; + +template +struct firstrNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + firstrNNRecord(DistanceType r_): r(r_) {} + + DistanceType operator()(PointHandle p, DistanceType d) + { + if (d <= r) { + result.push_back(HandleDistance(p,d)); + return -100000000.0; + } else { + return r; + } + } + + DistanceType r; + HDContainer result; +}; + + +template +struct kNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + kNNRecord(unsigned k_): k(k_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (result.size() < k) + { + result.push_back(HandleDistance(p,d)); + boost::push_heap(result); + if (result.size() < k) + return std::numeric_limits::infinity(); + } else if (d < result[0].d) + { + boost::pop_heap(result); + result.back() = HandleDistance(p,d); + boost::push_heap(result); + } + if ( result.size() > 1 ) { + assert( result[0].d >= result[1].d ); + } + return result[0].d; + } + + unsigned k; + HDContainer result; +}; + +} // dnn +} // bt +} // hera + +#endif // HERA_BT_DNN_LOCAL_SEARCH_FUNCTORS_H diff --git a/src/dionysus/bottleneck/dnn/parallel/tbb.h b/src/dionysus/bottleneck/dnn/parallel/tbb.h new file mode 100755 index 0000000..14f0093 --- /dev/null +++ b/src/dionysus/bottleneck/dnn/parallel/tbb.h @@ -0,0 +1,235 @@ +#ifndef HERA_BT_PARALLEL_H +#define HERA_BT_PARALLEL_H + +#ifndef FOR_R_TDA +#include +#endif + +#include + +#include +#include +#include + +#ifdef TBB + +#include +#include +#include + +#include +#include +#include + +namespace hera { +namespace bt { +namespace dnn +{ + using tbb::mutex; + using tbb::task_scheduler_init; + using tbb::task_group; + using tbb::task; + + template + struct vector + { + typedef tbb::concurrent_vector type; + }; + + template + struct atomic + { + typedef tbb::atomic type; + static T compare_and_swap(type& v, T n, T o) { return v.compare_and_swap(n,o); } + }; + + template + void do_foreach(Iterator begin, Iterator end, const F& f) { tbb::parallel_do(begin, end, f); } + + template + void for_each_range_(const Range& r, const F& f) + { + for (typename Range::iterator cur = r.begin(); cur != r.end(); ++cur) + f(*cur); + } + + template + void for_each_range(size_t from, size_t to, const F& f) + { + //static tbb::affinity_partitioner ap; + //tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f), ap); + tbb::parallel_for(from, to, f); + } + + template + void for_each_range(const Container& c, const F& f) + { + //static tbb::affinity_partitioner ap; + //tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f), ap); + tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f)); + } + + template + void for_each_range(Container& c, const F& f) + { + //static tbb::affinity_partitioner ap; + //tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f), ap); + tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f)); + } + + template + struct map_traits + { + typedef tbb::concurrent_hash_map type; + typedef typename type::range_type range; + }; + + struct progress_timer + { + progress_timer(): start(tbb::tick_count::now()) {} + ~progress_timer() + { +#ifndef FOR_R_TDA + std::cout << (tbb::tick_count::now() - start).seconds() << " s" << std::endl; +#endif + } + + tbb::tick_count start; + }; +} +} +} + +// Serialization for tbb::concurrent_vector<...> +namespace boost +{ + namespace serialization + { + template + void save(Archive& ar, const tbb::concurrent_vector& v, const unsigned int file_version) + { stl::save_collection(ar, v); } + + template + void load(Archive& ar, tbb::concurrent_vector& v, const unsigned int file_version) + { + stl::load_collection, + stl::archive_input_seq< Archive, tbb::concurrent_vector >, + stl::reserve_imp< tbb::concurrent_vector > + >(ar, v); + } + + template + void serialize(Archive& ar, tbb::concurrent_vector& v, const unsigned int file_version) + { split_free(ar, v, file_version); } + + template + void save(Archive& ar, const tbb::atomic& v, const unsigned int file_version) + { T v_ = v; ar << v_; } + + template + void load(Archive& ar, tbb::atomic& v, const unsigned int file_version) + { T v_; ar >> v_; v = v_; } + + template + void serialize(Archive& ar, tbb::atomic& v, const unsigned int file_version) + { split_free(ar, v, file_version); } + } +} + +#else + +#include +#include +#include + +namespace hera { +namespace bt { +namespace dnn +{ + template + struct vector + { + typedef ::std::vector type; + }; + + template + struct atomic + { + typedef T type; + static T compare_and_swap(type& v, T n, T o) { if (v != o) return v; v = n; return o; } + }; + + template + void do_foreach(Iterator begin, Iterator end, const F& f) { std::for_each(begin, end, f); } + + template + void for_each_range(size_t from, size_t to, const F& f) + { + for (size_t i = from; i < to; ++i) + f(i); + } + + template + void for_each_range(Container& c, const F& f) + { + BOOST_FOREACH(const typename Container::value_type& i, c) + f(i); + } + + template + void for_each_range(const Container& c, const F& f) + { + BOOST_FOREACH(const typename Container::value_type& i, c) + f(i); + } + + struct mutex + { + struct scoped_lock + { + scoped_lock() {} + scoped_lock(mutex& ) {} + void acquire(mutex& ) const {} + void release() const {} + }; + }; + + struct task_scheduler_init + { + task_scheduler_init(unsigned) {} + void initialize(unsigned) {} + static const unsigned automatic = 0; + static const unsigned deferred = 0; + }; + + struct task_group + { + template + void run(const Functor& f) const { f(); } + void wait() const {} + }; + + template + struct map_traits + { + typedef std::map type; + typedef type range; + }; + + using boost::progress_timer; +} +} +} + +#endif // TBB + +namespace dnn +{ + template + void do_foreach(const Range& range, const F& f) { do_foreach(boost::begin(range), boost::end(range), f); } +} + +#endif diff --git a/src/dionysus/bottleneck/dnn/parallel/utils.h b/src/dionysus/bottleneck/dnn/parallel/utils.h new file mode 100755 index 0000000..9809e77 --- /dev/null +++ b/src/dionysus/bottleneck/dnn/parallel/utils.h @@ -0,0 +1,100 @@ +#ifndef HERA_BT_PARALLEL_UTILS_H +#define HERA_BT_PARALLEL_UTILS_H + +#include "../utils.h" + +namespace hera +{ +namespace bt +{ +namespace dnn +{ + // Assumes rng is synchronized across ranks + template + void shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty = DataVector()); + + template + void shuffle(mpi::communicator& world, DataVector& data, RNGType& rng) + { + typedef decltype(data[0]) T; + shuffle(world, data, rng, [](T& x, T& y) { std::swap(x,y); }); + } +} +} +} + +template +void +hera::bt::dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty) +{ + // This is not a perfect shuffle: it dishes out data in chunks of 1/size. + // (It can be interpreted as generating a bistochastic matrix by taking the + // sum of size random permutation matrices.) Hopefully, it works for our purposes. + + typedef typename RNGType::result_type RNGResult; + + int size = world.size(); + int rank = world.rank(); + + // Generate local seeds + boost::uniform_int uniform; + RNGResult seed; + for (size_t i = 0; i < size; ++i) + { + RNGResult v = uniform(rng); + if (i == rank) + seed = v; + } + RNGType local_rng(seed); + + // Shuffle local data + hera::bt::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + + // Decide how much of our data goes to i-th processor + std::vector out_counts(size); + std::vector ranks(boost::counting_iterator(0), + boost::counting_iterator(size)); + for (size_t i = 0; i < size; ++i) + { + hera::bt::dnn::random_shuffle(ranks.begin(), ranks.end(), rng); + ++out_counts[ranks[rank]]; + } + + // Fill the outgoing array + size_t total = 0; + std::vector< DataVector > outgoing(size, empty); + for (size_t i = 0; i < size; ++i) + { + size_t count = data.size()*out_counts[i]/size; + if (total + count > data.size()) + count = data.size() - total; + + outgoing[i].reserve(count); + for (size_t j = total; j < total + count; ++j) + outgoing[i].push_back(data[j]); + + total += count; + } + + boost::uniform_int uniform_outgoing(0,size-1); // in range [0,size-1] + while(total < data.size()) // send leftover to random processes + { + outgoing[uniform_outgoing(local_rng)].push_back(data[total]); + ++total; + } + data.clear(); + + // Exchange the data + std::vector< DataVector > incoming(size, empty); + mpi::all_to_all(world, outgoing, incoming); + outgoing.clear(); + + // Assemble our data + for(const DataVector& vec : incoming) + for (size_t i = 0; i < vec.size(); ++i) + data.push_back(vec[i]); + hera::bt::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + // XXX: the final shuffle is irrelevant for our purposes. But it's also cheap. +} + +#endif diff --git a/src/dionysus/bottleneck/dnn/utils.h b/src/dionysus/bottleneck/dnn/utils.h new file mode 100755 index 0000000..f4ce632 --- /dev/null +++ b/src/dionysus/bottleneck/dnn/utils.h @@ -0,0 +1,47 @@ +#ifndef HERA_BT_DNN_UTILS_H +#define HERA_BT_DNN_UTILS_H + +#include +#include +#include + +namespace hera +{ +namespace bt +{ +namespace dnn +{ + +template +struct has_coordinates +{ + template ().coordinate(std::declval()...) )> + static std::true_type test(int); + + template + static std::false_type test(...); + + static constexpr bool value = decltype(test(0))::value; +}; + +template +void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& g, const SwapFunctor& swap) +{ + size_t n = last - first; + boost::uniform_int uniform(0,n); + for (size_t i = n-1; i > 0; --i) + swap(first[i], first[uniform(g,i+1)]); // picks a random number in [0,i] range +} + +template +void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& g) +{ + typedef decltype(*first) T; + random_shuffle(first, last, g, [](T& x, T& y) { std::swap(x,y); }); +} + +} // dnn +} // bt +} // hera + +#endif diff --git a/src/dionysus/bottleneck/neighb_oracle.h b/src/dionysus/bottleneck/neighb_oracle.h new file mode 100755 index 0000000..c3751b3 --- /dev/null +++ b/src/dionysus/bottleneck/neighb_oracle.h @@ -0,0 +1,295 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_NEIGHB_ORACLE_H +#define HERA_NEIGHB_ORACLE_H + +#include +#include +#include + +#include "basic_defs_bt.h" +#include "dnn/geometry/euclidean-fixed.h" +#include "dnn/local/kd-tree.h" + + + +namespace hera { +namespace bt { + +template +class NeighbOracleSimple +{ +public: + using DgmPoint = DiagramPoint; + using DgmPointSet = DiagramPointSet; + +private: + Real r; + Real distEpsilon; + DgmPointSet pointSet; + +public: + + NeighbOracleSimple() : r(0.0) {} + + NeighbOracleSimple(const DgmPointSet& _pointSet, const Real _r, const Real _distEpsilon) : + r(_r), + distEpsilon(_distEpsilon), + pointSet(_pointSet) + {} + + void deletePoint(const DgmPoint& p) + { + pointSet.erase(p); + } + + void rebuild(const DgmPointSet& S, const double rr) + { + pointSet = S; + r = rr; + } + + bool getNeighbour(const DgmPoint& q, DgmPoint& result) const + { + for(auto pit = pointSet.cbegin(); pit != pointSet.cend(); ++pit) { + if ( distLInf(*pit, q) <= r) { + result = *pit; + return true; + } + } + return false; + } + + void getAllNeighbours(const DgmPoint& q, std::vector& result) + { + result.clear(); + for(const auto& point : pointSet) { + if ( distLInf(point, q) <= r) { + result.push_back(point); + } + } + for(auto& pt : result) { + deletePoint(pt); + } + } + +}; + +template +class NeighbOracleDnn +{ +public: + + using Real = Real_; + using DnnPoint = dnn::Point<2, double>; + using DnnTraits = dnn::PointTraits; + using DgmPoint = DiagramPoint; + using DgmPointSet = DiagramPointSet; + using DgmPointHash = DiagramPointHash; + + Real r; + Real distEpsilon; + std::vector allPoints; + DgmPointSet diagonalPoints; + std::unordered_map pointIdxLookup; + // dnn-stuff + std::unique_ptr> kdtree; + std::vector dnnPoints; + std::vector dnnPointHandles; + std::vector kdtreeItems; + + NeighbOracleDnn(const DgmPointSet& S, const Real rr, const Real dEps) : + kdtree(nullptr) + { + assert(dEps >= 0); + distEpsilon = dEps; + rebuild(S, rr); + } + + + void deletePoint(const DgmPoint& p) + { + auto findRes = pointIdxLookup.find(p); + assert(findRes != pointIdxLookup.end()); + //std::cout << "Deleting point " << p << std::endl; + size_t pointIdx { (*findRes).second }; + //std::cout << "pointIdx = " << pointIdx << std::endl; + diagonalPoints.erase(p, false); + kdtree->delete_point(dnnPointHandles[kdtreeItems[pointIdx]]); + } + + void rebuild(const DgmPointSet& S, const Real rr) + { + //std::cout << "Entered rebuild, r = " << rr << std::endl; + r = rr; + size_t dnnNumPoints = S.size(); + //printDebug(isDebug, "S = ", S); + if (dnnNumPoints > 0) { + pointIdxLookup.clear(); + pointIdxLookup.reserve(S.size()); + allPoints.clear(); + allPoints.reserve(S.size()); + diagonalPoints.clear(); + diagonalPoints.reserve(S.size() / 2); + for(auto pit = S.cbegin(); pit != S.cend(); ++pit) { + allPoints.push_back(*pit); + if (pit->isDiagonal()) { + diagonalPoints.insert(*pit); + } + } + + size_t pointIdx = 0; + for(auto& dataPoint : allPoints) { + pointIdxLookup.insert( { dataPoint, pointIdx++ } ); + } + + size_t dnnItemIdx { 0 }; + size_t trueIdx { 0 }; + dnnPoints.clear(); + kdtreeItems.clear(); + dnnPointHandles.clear(); + dnnPoints.clear(); + kdtreeItems.reserve(S.size() ); + // store normal items in kd-tree + for(const auto& g : allPoints) { + if (true) { + kdtreeItems[trueIdx] = dnnItemIdx; + // index of items is id of dnn-point + DnnPoint p(trueIdx); + p[0] = g.getRealX(); + p[1] = g.getRealY(); + dnnPoints.push_back(p); + assert(dnnItemIdx == dnnPoints.size() - 1); + dnnItemIdx++; + } + trueIdx++; + } + assert(dnnPoints.size() == allPoints.size() ); + for(size_t i = 0; i < dnnPoints.size(); ++i) { + dnnPointHandles.push_back(&dnnPoints[i]); + } + DnnTraits traits; + //std::cout << "kdtree: " << dnnPointHandles.size() << " points" << std::endl; + kdtree.reset(new dnn::KDTree(traits, dnnPointHandles)); + } + } + + + bool getNeighbour(const DgmPoint& q, DgmPoint& result) const + { + //std::cout << "getNeighbour for q = " << q << ", r = " << r << std::endl; + //std::cout << *this << std::endl; + // distance between two diagonal points + // is 0 + if (q.isDiagonal()) { + if (!diagonalPoints.empty()) { + result = *diagonalPoints.cbegin(); + //std::cout << "Neighbour found in diagonal points, res = " << result; + return true; + } + } + // check if kdtree is not empty + if (0 == kdtree->get_num_points() ) { + //std::cout << "empty tree, no neighb." << std::endl; + return false; + } + // if no neighbour found among diagonal points, + // search in kd_tree + DnnPoint queryPoint; + queryPoint[0] = q.getRealX(); + queryPoint[1] = q.getRealY(); + auto kdtreeResult = kdtree->findFirstR(queryPoint, r); + if (kdtreeResult.empty()) { + //std::cout << "no neighbour within " << r << "found." << std::endl; + return false; + } + if (kdtreeResult[0].d <= r + distEpsilon) { + result = allPoints[kdtreeResult[0].p->id()]; + //std::cout << "Neighbour found with kd-tree, index = " << kdtreeResult[0].p->id() << std::endl; + //std::cout << "result = " << result << std::endl; + return true; + } + //std::cout << "No neighbour found for r = " << r << std::endl; + return false; + } + + + + void getAllNeighbours(const DgmPoint& q, std::vector& result) + { + //std::cout << "Entered getAllNeighbours for q = " << q << std::endl; + result.clear(); + // add diagonal points, if necessary + if ( q.isDiagonal() ) { + for( auto& diagPt : diagonalPoints ) { + result.push_back(diagPt); + } + } + // delete diagonal points we found + // to prevent finding them again + for(auto& pt : result) { + //std::cout << "deleting DIAG point pt = " << pt << std::endl; + deletePoint(pt); + } + size_t diagOffset = result.size(); + std::vector pointIndicesOut; + // perorm range search on kd-tree + DnnPoint queryPoint; + queryPoint[0] = q.getRealX(); + queryPoint[1] = q.getRealY(); + auto kdtreeResult = kdtree->findR(queryPoint, r); + pointIndicesOut.reserve(kdtreeResult.size()); + for(auto& handleDist : kdtreeResult) { + if (handleDist.d <= r + distEpsilon) { + pointIndicesOut.push_back(handleDist.p->id()); + } else { + break; + } + } + // get actual points in result + for(auto& ptIdx : pointIndicesOut) { + result.push_back(allPoints[ptIdx]); + } + // delete all points we found + for(auto ptIt = result.begin() + diagOffset; ptIt != result.end(); ++ptIt) { + //printDebug(isDebug, "deleting point pt = ", *ptIt); + deletePoint(*ptIt); + } + } + + //DgmPointSet originalPointSet; + template + friend std::ostream& operator<<(std::ostream& out, const NeighbOracleDnn& oracle); + +}; + +} // end namespace bt +} // end namespace hera + +#endif // HERA_NEIGHB_ORACLE_H diff --git a/src/dionysus/wasserstein/auction_oracle.h b/src/dionysus/wasserstein/auction_oracle.h new file mode 100755 index 0000000..d285a1f --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle.h @@ -0,0 +1,40 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef HERA_AUCTION_ORACLE_H +#define HERA_AUCTION_ORACLE_H + +// all oracle classes are in separate h-hpp files +// this file comprises all of them + +#include "auction_oracle_base.h" +#include "auction_oracle_kdtree_restricted.h" +#include "auction_oracle_kdtree_single_diag.h" +#include "auction_oracle_stupid_sparse_restricted.h" + +#endif // HERA_AUCTION_ORACLE_H diff --git a/src/dionysus/wasserstein/auction_oracle_base.h b/src/dionysus/wasserstein/auction_oracle_base.h new file mode 100755 index 0000000..08eaf00 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_base.h @@ -0,0 +1,85 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_BASE_H +#define AUCTION_ORACLE_BASE_H + +#include +#include +#include +#include + +#include "basic_defs_ws.h" + +namespace hera { +namespace ws { + + +template +struct DebugOptimalBid { + DebugOptimalBid() : best_item_idx(k_invalid_index), best_item_value(-666.666), second_best_item_idx(k_invalid_index), second_best_item_value(-666.666) {}; + IdxType best_item_idx; + Real best_item_value; + IdxType second_best_item_idx; + Real second_best_item_value; +}; + +template >> +struct AuctionOracleBase { + AuctionOracleBase(const PointContainer_& _bidders, const PointContainer_& _items, const AuctionParams& params); + ~AuctionOracleBase() {} + Real get_epsilon() const { return epsilon; }; + void set_epsilon(Real new_epsilon) { assert(new_epsilon >= 0.0); epsilon = new_epsilon; }; + const std::vector& get_prices() const { return prices; } + virtual Real get_price(const size_t item_idx) const { return prices[item_idx]; } // TODO make virtual? +//protected: + const PointContainer_& bidders; + const PointContainer_& items; + const size_t num_bidders_; + const size_t num_items_; + std::vector prices; + const Real wasserstein_power; + Real epsilon; + const Real internal_p; + unsigned int dim; // used only in pure geometric version, not for persistence diagrams + Real get_value_for_bidder(size_t bidder_idx, size_t item_idx) const; + Real get_value_for_diagonal_bidder(size_t item_idx) const; + Real get_cost_for_diagonal_bidder(size_t item_idx) const; +}; + + +template +std::ostream& operator<< (std::ostream& output, const DebugOptimalBid& db); + +} // ws +} // hera + + +#include "auction_oracle_base.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_base.hpp b/src/dionysus/wasserstein/auction_oracle_base.hpp new file mode 100755 index 0000000..b74c7fb --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_base.hpp @@ -0,0 +1,97 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_BASE_HPP +#define AUCTION_ORACLE_BASE_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "auction_oracle.h" + + +#ifdef FOR_R_TDA +#undef DEBUG_AUCTION +#endif + +namespace hera { +namespace ws { + +template +AuctionOracleBase::AuctionOracleBase(const PointContainer& _bidders, + const PointContainer& _items, + const AuctionParams& params) : + bidders(_bidders), + items(_items), + num_bidders_(_bidders.size()), + num_items_(_items.size()), + prices(items.size(), Real(0.0)), + wasserstein_power(params.wasserstein_power), + internal_p(params.internal_p), + dim(params.dim) +{ + assert(bidders.size() == items.size() ); +} + + +template +Real AuctionOracleBase::get_value_for_bidder(size_t bidder_idx, size_t item_idx) const +{ + return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p, dim), wasserstein_power) + get_price(item_idx); +} + +template +Real AuctionOracleBase::get_value_for_diagonal_bidder(size_t item_idx) const +{ + return get_cost_for_diagonal_bidder(item_idx) + get_price(item_idx); +} + +template +Real AuctionOracleBase::get_cost_for_diagonal_bidder(size_t item_idx) const +{ + return std::pow(items[item_idx].persistence_lp(internal_p), wasserstein_power); +} + + + +template +std::ostream& operator<< (std::ostream& output, const DebugOptimalBid& db) +{ + output << "best_item_value = " << db.best_item_value; + output << "; best_item_idx = " << db.best_item_idx; + output << "; second_best_item_value = " << db.second_best_item_value; + output << "; second_best_item_idx = " << db.second_best_item_idx; + return output; +} + +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.h b/src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.h new file mode 100755 index 0000000..096583e --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.h @@ -0,0 +1,97 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_KDTREE_PURE_GEOM_H +#define AUCTION_ORACLE_KDTREE_PURE_GEOM_H + + +#include +#include +#include + +#include + +namespace ba = boost::adaptors; + +#include "spdlog/spdlog.h" +#include "basic_defs_ws.h" +#include "auction_oracle_base.h" +#include "dnn/geometry/euclidean-dynamic.h" +#include "dnn/local/kd-tree.h" + +namespace hera +{ +namespace ws +{ + +template > +struct AuctionOracleKDTreePureGeom : AuctionOracleBase { + + using Real = Real_; + using DynamicPointTraitsR = typename hera::ws::dnn::DynamicPointTraits; + using DiagramPointR = typename DynamicPointTraitsR::PointType; + using PointHandleR = typename DynamicPointTraitsR::PointHandle; + using PointContainer = PointContainer_; + using DebugOptimalBidR = typename ws::DebugOptimalBid; + + using DynamicPointTraits = hera::ws::dnn::DynamicPointTraits; + using KDTreeR = hera::ws::dnn::KDTree; + + AuctionOracleKDTreePureGeom(const PointContainer& bidders, const PointContainer& items, const AuctionParams& params); + ~AuctionOracleKDTreePureGeom(); + + // data members + // temporarily make everything public + DynamicPointTraits traits; + Real max_val_; + Real weight_adj_const_; + std::unique_ptr kdtree_; + std::vector kdtree_items_; + // methods + void set_price(const IdxType items_idx, const Real new_price); + IdxValPair get_optimal_bid(const IdxType bidder_idx); + void adjust_prices(); + void adjust_prices(const Real delta); + + // debug routines + DebugOptimalBidR get_optimal_bid_debug(IdxType bidder_idx) const; + void sanity_check(); + + std::shared_ptr console_logger; + + std::pair get_minmax_price() const; + +}; + +} // ws +} // hera + + +#include "auction_oracle_kdtree_pure_geom.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.hpp b/src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.hpp new file mode 100755 index 0000000..a6bdf10 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_kdtree_pure_geom.hpp @@ -0,0 +1,244 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ +#ifndef AUCTION_ORACLE_KDTREE_PURE_GEOM_HPP +#define AUCTION_ORACLE_KDTREE_PURE_GEOM_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "auction_oracle_kdtree_restricted.h" + + +#ifdef FOR_R_TDA +#undef DEBUG_AUCTION +#endif + +namespace hera { +namespace ws { + + +// ***************************** +// AuctionOracleKDTreePureGeom +// ***************************** + + + +template +std::ostream& operator<<(std::ostream& output, const AuctionOracleKDTreePureGeom& oracle) +{ + output << "Oracle " << &oracle << std::endl; + output << fmt::format(" max_val_ = {0}\n", + oracle.max_val_); + + output << fmt::format(" prices = {0}\n", + format_container_to_log(oracle.prices)); + + output << "end of oracle " << &oracle << std::endl; + return output; +} + + +template +AuctionOracleKDTreePureGeom::AuctionOracleKDTreePureGeom(const PointContainer_& _bidders, + const PointContainer_& _items, + const AuctionParams& params) : + AuctionOracleBase(_bidders, _items, params), + traits(params.dim) +{ + + traits.internal_p = params.internal_p; + + std::vector item_handles(this->num_items_); + for(size_t i = 0; i < this->num_items_; ++i) { + item_handles[i] = traits.handle(this->items[i]); + } + + //kdtree_ = std::unique_ptr(new KDTreeR(traits, + // this->items | ba::transformed([this](const DiagramPointR& p) { return traits.handle(p); }), + // params.wasserstein_power)); + + kdtree_ = std::unique_ptr(new KDTreeR(traits, item_handles, params.wasserstein_power)); + + + max_val_ = 3*getFurthestDistance3Approx_pg(this->bidders, this->items, params.internal_p, params.dim); + max_val_ = std::pow(max_val_, params.wasserstein_power); + weight_adj_const_ = max_val_; + + console_logger = spdlog::get("console"); + if (not console_logger) { + console_logger = spdlog::stdout_logger_st("console"); + } + console_logger->set_pattern("[%H:%M:%S.%e] %v"); + console_logger->debug("KDTree Restricted oracle ctor done"); +} + + +template +typename AuctionOracleKDTreePureGeom::DebugOptimalBidR +AuctionOracleKDTreePureGeom::get_optimal_bid_debug(IdxType bidder_idx) const +{ + auto bidder = this->bidders[bidder_idx]; + + size_t best_item_idx = k_invalid_index; + size_t second_best_item_idx = k_invalid_index; + Real best_item_value = std::numeric_limits::max(); + Real second_best_item_value = std::numeric_limits::max(); + + for(IdxType item_idx = 0; item_idx < this->items.size(); ++item_idx) { + auto item = this->items[item_idx]; + if (item.type != bidder.type and item_idx != bidder_idx) + continue; + auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power, this->dim) + this->prices[item_idx]; + if (item_value < best_item_value) { + best_item_value = item_value; + best_item_idx = item_idx; + } + } + + assert(best_item_idx != k_invalid_index); + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + auto item = this->items[item_idx]; + if (item.type != bidder.type and item_idx != bidder_idx) + continue; + if (item_idx == best_item_idx) + continue; + auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power, this->dim) + this->prices[item_idx]; + if (item_value < second_best_item_value) { + second_best_item_value = item_value; + second_best_item_idx = item_idx; + } + } + + assert(second_best_item_idx != k_invalid_index); + assert(second_best_item_value >= best_item_value); + + DebugOptimalBidR result; + + result.best_item_idx = best_item_idx; + result.best_item_value = best_item_value; + result.second_best_item_idx = second_best_item_idx; + result.second_best_item_value = second_best_item_value; + + return result; +} + + +template +IdxValPair AuctionOracleKDTreePureGeom::get_optimal_bid(IdxType bidder_idx) +{ + auto two_best_items = kdtree_->findK(this->bidders[bidder_idx], 2); + size_t best_item_idx = traits.id(two_best_items[0].p); + Real best_item_value = two_best_items[0].d; + Real second_best_item_value = two_best_items[1].d; + + IdxValPair result; + + assert( second_best_item_value >= best_item_value ); + + result.first = best_item_idx; + result.second = ( second_best_item_value - best_item_value ) + this->prices[best_item_idx] + this->epsilon; + + return result; +} + +/* +a_{ij} = d_{ij} +value_{ij} = a_{ij} + price_j +*/ + +template +void AuctionOracleKDTreePureGeom::set_price(IdxType item_idx, + Real new_price) +{ + + console_logger->debug("Enter set_price, item_idx = {0}, new_price = {1}, old price = {2}", item_idx, new_price, this->prices[item_idx]); + + assert(this->prices.size() == this->items.size()); + // adjust_prices decreases prices, + // also this variable must be true in reverse phases of FR-auction + + this->prices[item_idx] = new_price; + kdtree_->change_weight( traits.handle(this->items[item_idx]), new_price); + + console_logger->debug("Exit set_price, item_idx = {0}, new_price = {1}", item_idx, new_price); +} + + +template +void AuctionOracleKDTreePureGeom::adjust_prices(Real delta) +{ + //console_logger->debug("Enter adjust_prices, delta = {0}", delta); + //std::cerr << *this << std::endl; + + if (delta == 0.0) + return; + + for(auto& p : this->prices) { + p -= delta; + } + + kdtree_->adjust_weights(delta); + + //std::cerr << *this << std::endl; + //console_logger->debug("Exit adjust_prices, delta = {0}", delta); +} + +template +void AuctionOracleKDTreePureGeom::adjust_prices() +{ + auto pr_begin = this->prices.begin(); + auto pr_end = this->prices.end(); + Real min_price = *(std::min_element(pr_begin, pr_end)); + adjust_prices(min_price); +} + +template +std::pair AuctionOracleKDTreePureGeom::get_minmax_price() const +{ + auto r = std::minmax_element(this->prices.begin(), this->prices.end()); + return std::make_pair(*r.first, *r.second); +} + +template +AuctionOracleKDTreePureGeom::~AuctionOracleKDTreePureGeom() +{ +} + +template +void AuctionOracleKDTreePureGeom::sanity_check() +{ +} + + +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_kdtree_restricted.h b/src/dionysus/wasserstein/auction_oracle_kdtree_restricted.h new file mode 100755 index 0000000..1999147 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_kdtree_restricted.h @@ -0,0 +1,122 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_KDTREE_RESTRICTED_H +#define AUCTION_ORACLE_KDTREE_RESTRICTED_H + + +//#define USE_BOOST_HEAP + +#include +#include +#include + + +#include "spdlog/spdlog.h" +#include "basic_defs_ws.h" +#include "diagonal_heap.h" +#include "auction_oracle_base.h" +#include "dnn/geometry/euclidean-fixed.h" +#include "dnn/local/kd-tree.h" + +namespace hera { +namespace ws { + +template >> +struct AuctionOracleKDTreeRestricted : AuctionOracleBase { + + using PointContainer = PointContainer_; + using Real = Real_; + + using LossesHeapR = typename ws::LossesHeapOld; + using LossesHeapRHandle = typename ws::LossesHeapOld::handle_type; + using DiagramPointR = typename ws::DiagramPoint; + using DebugOptimalBidR = typename ws::DebugOptimalBid; + + using DnnPoint = dnn::Point<2, Real>; + using DnnTraits = dnn::PointTraits; + + AuctionOracleKDTreeRestricted(const PointContainer& bidders, const PointContainer& items, const AuctionParams& params); + ~AuctionOracleKDTreeRestricted(); + // data members + // temporarily make everything public + Real max_val_; + Real weight_adj_const_; + dnn::KDTree* kdtree_; + std::vector dnn_points_; + std::vector dnn_point_handles_; + LossesHeapR diag_items_heap_; + std::vector diag_heap_handles_; + std::vector heap_handles_indices_; + std::vector kdtree_items_; + std::vector top_diag_indices_; + std::vector top_diag_lookup_; + size_t top_diag_counter_ { 0 }; + bool best_diagonal_items_computed_ { false }; + Real best_diagonal_item_value_; + size_t second_best_diagonal_item_idx_ { k_invalid_index }; + Real second_best_diagonal_item_value_ { std::numeric_limits::max() }; + + + // methods + void set_price(const IdxType items_idx, const Real new_price, const bool update_diag = true); + IdxValPair get_optimal_bid(const IdxType bidder_idx); + void adjust_prices(); + void adjust_prices(const Real delta); + + // debug routines + DebugOptimalBidR get_optimal_bid_debug(IdxType bidder_idx) const; + void sanity_check(); + + + // heap top vector + size_t get_heap_top_size() const; + void recompute_top_diag_items(bool hard = false); + void recompute_second_best_diag(); + void reset_top_diag_counter(); + void increment_top_diag_counter(); + void add_top_diag_index(const size_t item_idx); + void remove_top_diag_index(const size_t item_idx); + bool is_in_top_diag_indices(const size_t item_idx) const; + + std::shared_ptr console_logger; + + std::pair get_minmax_price() const; + +}; + +template +std::ostream& operator<< (std::ostream& output, const DebugOptimalBid& db); + +} // ws +} // hera + + +#include "auction_oracle_kdtree_restricted.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_kdtree_restricted.hpp b/src/dionysus/wasserstein/auction_oracle_kdtree_restricted.hpp new file mode 100755 index 0000000..0e6f780 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_kdtree_restricted.hpp @@ -0,0 +1,598 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ +#ifndef AUCTION_ORACLE_KDTREE_RESTRICTED_HPP +#define AUCTION_ORACLE_KDTREE_RESTRICTED_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "auction_oracle_kdtree_restricted.h" + + +#ifdef FOR_R_TDA +#undef DEBUG_AUCTION +#endif + +namespace hera { +namespace ws { + + +// ***************************** +// AuctionOracleKDTreeRestricted +// ***************************** + + + +template +std::ostream& operator<<(std::ostream& output, const AuctionOracleKDTreeRestricted& oracle) +{ + output << "Oracle " << &oracle << std::endl; + output << fmt::format(" max_val_ = {0}, best_diagonal_items_computed_ = {1}, best_diagonal_item_value_ = {2}, second_best_diagonal_item_idx_ = {3}, second_best_diagonal_item_value_ = {4}\n", + oracle.max_val_, + oracle.best_diagonal_items_computed_, + oracle.best_diagonal_item_value_, + oracle.second_best_diagonal_item_idx_, + oracle.second_best_diagonal_item_value_); + + output << fmt::format(" prices = {0}\n", + format_container_to_log(oracle.prices)); + + output << fmt::format(" diag_items_heap_ = {0}\n", + losses_heap_to_string(oracle.diag_items_heap_)); + + + output << fmt::format(" top_diag_indices_ = {0}\n", + format_container_to_log(oracle.top_diag_indices_)); + + output << fmt::format(" top_diag_counter_ = {0}\n", + oracle.top_diag_counter_); + + output << fmt::format(" top_diag_lookup_ = {0}\n", + format_container_to_log(oracle.top_diag_lookup_)); + + + output << "end of oracle " << &oracle << std::endl; + return output; +} + + +template +AuctionOracleKDTreeRestricted::AuctionOracleKDTreeRestricted(const PointContainer_& _bidders, + const PointContainer_& _items, + const AuctionParams& params) : + AuctionOracleBase(_bidders, _items, params), + heap_handles_indices_(_items.size(), k_invalid_index), + kdtree_items_(_items.size(), k_invalid_index), + top_diag_lookup_(_items.size(), k_invalid_index) +{ + size_t dnn_item_idx { 0 }; + size_t true_idx { 0 }; + dnn_points_.clear(); + dnn_points_.reserve(this->items.size()); + // store normal items in kd-tree + for(const auto& g : this->items) { + if (g.is_normal() ) { + kdtree_items_[true_idx] = dnn_item_idx; + // index of items is id of dnn-point + DnnPoint p(true_idx); + p[0] = g.getRealX(); + p[1] = g.getRealY(); + dnn_points_.push_back(p); + assert(dnn_item_idx == dnn_points_.size() - 1); + dnn_item_idx++; + } + true_idx++; + } + + assert(dnn_points_.size() < _items.size() ); + for(size_t i = 0; i < dnn_points_.size(); ++i) { + dnn_point_handles_.push_back(&dnn_points_[i]); + } + DnnTraits traits; + traits.internal_p = params.internal_p; + kdtree_ = new dnn::KDTree(traits, dnn_point_handles_, params.wasserstein_power); + + size_t handle_idx {0}; + for(size_t item_idx = 0; item_idx < _items.size(); ++item_idx) { + if (this->items[item_idx].is_diagonal()) { + heap_handles_indices_[item_idx] = handle_idx++; + diag_heap_handles_.push_back(diag_items_heap_.push(std::make_pair(item_idx, 0.0))); + } + } + max_val_ = 3*getFurthestDistance3Approx<>(_bidders, _items, params.internal_p); + max_val_ = std::pow(max_val_, params.wasserstein_power); + weight_adj_const_ = max_val_; + + console_logger = spdlog::get("console"); + if (not console_logger) { + console_logger = spdlog::stdout_logger_st("console"); + } + console_logger->set_pattern("[%H:%M:%S.%e] %v"); + console_logger->debug("KDTree Restricted oracle ctor done"); +} + + +template +bool AuctionOracleKDTreeRestricted::is_in_top_diag_indices(const size_t item_idx) const +{ + return top_diag_lookup_[item_idx] != k_invalid_index; +} + + +template +void AuctionOracleKDTreeRestricted::add_top_diag_index(const size_t item_idx) +{ + assert(find(top_diag_indices_.begin(), top_diag_indices_.end(), item_idx) == top_diag_indices_.end()); + assert(this->items[item_idx].is_diagonal()); + + top_diag_indices_.push_back(item_idx); + top_diag_lookup_[item_idx] = top_diag_indices_.size() - 1; +} + +template +void AuctionOracleKDTreeRestricted::remove_top_diag_index(const size_t item_idx) +{ + if (top_diag_indices_.size() > 1) { + // remove item_idx from top_diag_indices after swapping + // it with the last element, update index lookup appropriately + auto old_index = top_diag_lookup_[item_idx]; + auto end_element = top_diag_indices_.back(); + std::swap(top_diag_indices_[old_index], top_diag_indices_.back()); + top_diag_lookup_[end_element] = old_index; + } + + top_diag_indices_.pop_back(); + top_diag_lookup_[item_idx] = k_invalid_index; + if (top_diag_indices_.size() < 2) { + recompute_second_best_diag(); + } + best_diagonal_items_computed_ = not top_diag_indices_.empty(); + reset_top_diag_counter(); +} + + +template +void AuctionOracleKDTreeRestricted::increment_top_diag_counter() +{ + assert(top_diag_counter_ >= 0 and top_diag_counter_ < top_diag_indices_.size()); + + ++top_diag_counter_; + if (top_diag_counter_ >= top_diag_indices_.size()) { + top_diag_counter_ -= top_diag_indices_.size(); + } + + assert(top_diag_counter_ >= 0 and top_diag_counter_ < top_diag_indices_.size()); +} + + +template +void AuctionOracleKDTreeRestricted::reset_top_diag_counter() +{ + top_diag_counter_ = 0; +} + +template +void AuctionOracleKDTreeRestricted::recompute_top_diag_items(bool hard) +{ + console_logger->debug("Enter recompute_top_diag_items, hard = {0}", hard); + assert(hard or top_diag_indices_.empty()); + + if (hard) { + std::fill(top_diag_lookup_.begin(), top_diag_lookup_.end(), k_invalid_index); + top_diag_indices_.clear(); + } + + auto top_diag_iter = diag_items_heap_.ordered_begin(); + best_diagonal_item_value_ = top_diag_iter->second; + add_top_diag_index(top_diag_iter->first); + + ++top_diag_iter; + + // traverse the heap while we see the same value + while(top_diag_iter != diag_items_heap_.ordered_end()) { + if ( top_diag_iter->second != best_diagonal_item_value_) { + break; + } else { + add_top_diag_index(top_diag_iter->first); + } + ++top_diag_iter; + } + + recompute_second_best_diag(); + + best_diagonal_items_computed_ = true; + reset_top_diag_counter(); + console_logger->debug("Exit recompute_top_diag_items, hard = {0}", hard); +} + +template +typename AuctionOracleKDTreeRestricted::DebugOptimalBidR +AuctionOracleKDTreeRestricted::get_optimal_bid_debug(IdxType bidder_idx) const +{ + auto bidder = this->bidders[bidder_idx]; + + size_t best_item_idx = k_invalid_index; + size_t second_best_item_idx = k_invalid_index; + Real best_item_value = std::numeric_limits::max(); + Real second_best_item_value = std::numeric_limits::max(); + + for(IdxType item_idx = 0; item_idx < this->items.size(); ++item_idx) { + auto item = this->items[item_idx]; + if (item.type != bidder.type and item_idx != bidder_idx) + continue; + auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power) + this->prices[item_idx]; + if (item_value < best_item_value) { + best_item_value = item_value; + best_item_idx = item_idx; + } + } + + assert(best_item_idx != k_invalid_index); + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + auto item = this->items[item_idx]; + if (item.type != bidder.type and item_idx != bidder_idx) + continue; + if (item_idx == best_item_idx) + continue; + auto item_value = std::pow(dist_lp(bidder, item, this->internal_p), this->wasserstein_power) + this->prices[item_idx]; + if (item_value < second_best_item_value) { + second_best_item_value = item_value; + second_best_item_idx = item_idx; + } + } + + assert(second_best_item_idx != k_invalid_index); + assert(second_best_item_value >= best_item_value); + + DebugOptimalBidR result; + + result.best_item_idx = best_item_idx; + result.best_item_value = best_item_value; + result.second_best_item_idx = second_best_item_idx; + result.second_best_item_value = second_best_item_value; + + return result; +} + + +template +IdxValPair AuctionOracleKDTreeRestricted::get_optimal_bid(IdxType bidder_idx) +{ + auto bidder = this->bidders[bidder_idx]; + + // corresponding point is always considered as a candidate + // if bidder is a diagonal point, proj_item is a normal point, + // and vice versa. + + size_t best_item_idx { k_invalid_index }; + size_t second_best_item_idx __attribute__((unused)) { k_invalid_index }; + size_t best_diagonal_item_idx { k_invalid_index }; + Real best_item_value; + Real second_best_item_value; + + + size_t proj_item_idx = bidder_idx; + assert( 0 <= proj_item_idx and proj_item_idx < this->items.size() ); + assert(this->items[proj_item_idx].type != bidder.type); + Real proj_item_value = this->get_value_for_bidder(bidder_idx, proj_item_idx); + + if (bidder.is_diagonal()) { + // for diagonal bidder the only normal point has already been added + // the other 2 candidates are diagonal items only, get from the heap + // with prices + + if (not best_diagonal_items_computed_) { + recompute_top_diag_items(); + } + + best_diagonal_item_idx = top_diag_indices_[top_diag_counter_]; + increment_top_diag_counter(); + + if ( proj_item_value < best_diagonal_item_value_) { + best_item_idx = proj_item_idx; + best_item_value = proj_item_value; + second_best_item_value = best_diagonal_item_value_; + second_best_item_idx = best_diagonal_item_idx; + } else if (proj_item_value < second_best_diagonal_item_value_) { + best_item_idx = best_diagonal_item_idx; + best_item_value = best_diagonal_item_value_; + second_best_item_value = proj_item_value; + second_best_item_idx = proj_item_idx; + } else { + best_item_idx = best_diagonal_item_idx; + best_item_value = best_diagonal_item_value_; + second_best_item_value = second_best_diagonal_item_value_; + second_best_item_idx = second_best_diagonal_item_idx_; + } + } else { + // for normal bidder get 2 best items among non-diagonal points from + // kdtree_ + DnnPoint bidder_dnn; + bidder_dnn[0] = bidder.getRealX(); + bidder_dnn[1] = bidder.getRealY(); + auto two_best_items = kdtree_->findK(bidder_dnn, 2); + size_t best_normal_item_idx { two_best_items[0].p->id() }; + Real best_normal_item_value { two_best_items[0].d }; + // if there is only one off-diagonal point in the second diagram, + // kd-tree will not return the second candidate. + // Set its value to inf, so it will always lose to the value of the projection + Real second_best_normal_item_value { two_best_items.size() == 1 ? std::numeric_limits::max() : two_best_items[1].d }; + + if ( proj_item_value < best_normal_item_value) { + best_item_idx = proj_item_idx; + best_item_value = proj_item_value; + second_best_item_value = best_normal_item_value; + } else if (proj_item_value < second_best_normal_item_value) { + best_item_idx = best_normal_item_idx; + best_item_value = best_normal_item_value; + second_best_item_value = proj_item_value; + } else { + best_item_idx = best_normal_item_idx; + best_item_value = best_normal_item_value; + second_best_item_value = second_best_normal_item_value; + } + } + + IdxValPair result; + + assert( second_best_item_value >= best_item_value ); + + result.first = best_item_idx; + result.second = ( second_best_item_value - best_item_value ) + this->prices[best_item_idx] + this->epsilon; + +#ifdef DEBUG_KDTREE_RESTR_ORACLE + auto db = get_optimal_bid_debug(bidder_idx); + assert(fabs(db.best_item_value - best_item_value) < 0.000001); + if (fabs(db.second_best_item_value - second_best_item_value) >= 0.000001) { + console_logger->debug("Bidder_idx = {0}, best_item_idx = {1}, true_best_item_idx = {2}", bidder_idx, best_item_idx, db.best_item_idx); + console_logger->debug("second_best_item_idx = {0}, true second_best_item_idx = {1}", second_best_item_idx, db.second_best_item_idx); + console_logger->debug("second_best_value = {0}, true second_best_item_value = {1}", second_best_item_value, db.second_best_item_value); + console_logger->debug("prices = {0}", format_container_to_log(this->prices)); + console_logger->debug("top_diag_indices_ = {0}", format_container_to_log(top_diag_indices_)); + console_logger->debug("second_best_diagonal_item_value_ = {0}", second_best_diagonal_item_value_); + } + assert(fabs(db.second_best_item_value - second_best_item_value) < 0.000001); + //std::cout << "bid OK" << std::endl; +#endif + + return result; +} +/* +a_{ij} = d_{ij} +value_{ij} = a_{ij} + price_j +*/ +template +void AuctionOracleKDTreeRestricted::recompute_second_best_diag() +{ + + console_logger->debug("Enter recompute_second_best_diag"); + + if (top_diag_indices_.size() > 1) { + second_best_diagonal_item_value_ = best_diagonal_item_value_; + second_best_diagonal_item_idx_ = top_diag_indices_[0]; + } else { + if (diag_items_heap_.size() == 1) { + second_best_diagonal_item_value_ = std::numeric_limits::max(); + second_best_diagonal_item_idx_ = k_invalid_index; + } else { + auto diag_iter = diag_items_heap_.ordered_begin(); + ++diag_iter; + second_best_diagonal_item_value_ = diag_iter->second; + second_best_diagonal_item_idx_ = diag_iter->first; + } + } + + console_logger->debug("Exit recompute_second_best_diag, second_best_diagonal_item_value_ = {0}, second_best_diagonal_item_idx_ = {1}", second_best_diagonal_item_value_, second_best_diagonal_item_idx_); +} + + +template +void AuctionOracleKDTreeRestricted::set_price(IdxType item_idx, + Real new_price, + const bool update_diag) +{ + + console_logger->debug("Enter set_price, item_idx = {0}, new_price = {1}, old price = {2}, update_diag = {3}", item_idx, new_price, this->prices[item_idx], update_diag); + + assert(this->prices.size() == this->items.size()); + assert( 0 < diag_heap_handles_.size() and diag_heap_handles_.size() <= this->items.size()); + // adjust_prices decreases prices, + // also this variable must be true in reverse phases of FR-auction + bool item_goes_down = new_price > this->prices[item_idx]; + + this->prices[item_idx] = new_price; + if ( this->items[item_idx].is_normal() ) { + assert(0 <= item_idx and item_idx < static_cast(kdtree_items_.size())); + assert(0 <= kdtree_items_[item_idx] and kdtree_items_[item_idx] < dnn_point_handles_.size()); + kdtree_->change_weight( dnn_point_handles_[kdtree_items_[item_idx]], new_price); + } else { + assert(diag_heap_handles_.size() > heap_handles_indices_.at(item_idx)); + if (item_goes_down) { + diag_items_heap_.decrease(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } else { + diag_items_heap_.increase(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } + if (update_diag) { + // Update top_diag_indices_ only if necessary: + // normal bidders take their projections, which might not be on top + // also, set_price is called by adjust_prices, and we may have already + // removed the item from top_diag + if (is_in_top_diag_indices(item_idx)) { + remove_top_diag_index(item_idx); + } + + if (item_idx == (IdxType)second_best_diagonal_item_idx_) { + recompute_second_best_diag(); + } + } + } + + console_logger->debug("Exit set_price, item_idx = {0}, new_price = {1}", item_idx, new_price); +} + + +template +void AuctionOracleKDTreeRestricted::adjust_prices(Real delta) +{ + //console_logger->debug("Enter adjust_prices, delta = {0}", delta); + //std::cerr << *this << std::endl; + + if (delta == 0.0) + return; + + for(auto& p : this->prices) { + p -= delta; + } + + kdtree_->adjust_weights(delta); + + bool price_goes_up = delta < 0; + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + if (this->items[item_idx].is_diagonal()) { + auto new_price = this->prices[item_idx]; + if (price_goes_up) { + diag_items_heap_.decrease(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } else { + diag_items_heap_.increase(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } + } + } + best_diagonal_item_value_ -= delta; + second_best_diagonal_item_value_ -= delta; + + //std::cerr << *this << std::endl; + //console_logger->debug("Exit adjust_prices, delta = {0}", delta); +} + +template +void AuctionOracleKDTreeRestricted::adjust_prices() +{ + auto pr_begin = this->prices.begin(); + auto pr_end = this->prices.end(); + Real min_price = *(std::min_element(pr_begin, pr_end)); + adjust_prices(min_price); +} + +template +size_t AuctionOracleKDTreeRestricted::get_heap_top_size() const +{ + return top_diag_indices_.size(); +} + +template +std::pair AuctionOracleKDTreeRestricted::get_minmax_price() const +{ + auto r = std::minmax_element(this->prices.begin(), this->prices.end()); + return std::make_pair(*r.first, *r.second); +} + + + +template +AuctionOracleKDTreeRestricted::~AuctionOracleKDTreeRestricted() +{ + delete kdtree_; +} + +template +void AuctionOracleKDTreeRestricted::sanity_check() +{ +#ifdef DEBUG_KDTREE_RESTR_ORACLE + if (best_diagonal_items_computed_) { + std::vector diag_items_price_vec; + diag_items_price_vec.reserve(this->items.size()); + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + if (this->items.at(item_idx).is_diagonal()) { + diag_items_price_vec.push_back(this->prices.at(item_idx)); + } else { + diag_items_price_vec.push_back(std::numeric_limits::max()); + } + } + + auto best_iter = std::min_element(diag_items_price_vec.begin(), diag_items_price_vec.end()); + assert(best_iter != diag_items_price_vec.end()); + Real true_best_diag_value = *best_iter; + size_t true_best_diag_idx = best_iter - diag_items_price_vec.begin(); + assert(true_best_diag_value != std::numeric_limits::max()); + + Real true_second_best_diag_value = std::numeric_limits::max(); + size_t true_second_best_diag_idx = k_invalid_index; + for(size_t item_idx = 0; item_idx < diag_items_price_vec.size(); ++item_idx) { + if (this->items.at(item_idx).is_normal()) { + assert(top_diag_lookup_.at(item_idx) == k_invalid_index); + continue; + } + + auto i_iter = std::find(top_diag_indices_.begin(), top_diag_indices_.end(), item_idx); + if (diag_items_price_vec.at(item_idx) == true_best_diag_value) { + assert(i_iter != top_diag_indices_.end()); + assert(top_diag_lookup_.at(item_idx) == i_iter - top_diag_indices_.begin()); + } else { + assert(top_diag_lookup_.at(item_idx) == k_invalid_index); + assert(i_iter == top_diag_indices_.end()); + } + + if (item_idx == true_best_diag_idx) { + continue; + } + if (diag_items_price_vec.at(item_idx) < true_second_best_diag_value) { + true_second_best_diag_value = diag_items_price_vec.at(item_idx); + true_second_best_diag_idx = item_idx; + } + } + + if (true_best_diag_value != best_diagonal_item_value_) { + console_logger->debug("best_diagonal_item_value_ = {0}, true value = {1}", best_diagonal_item_value_, true_best_diag_value); + std::cerr << *this; + //console_logger->debug("{0}", *this); + } + + assert(true_best_diag_value == best_diagonal_item_value_); + + assert(true_second_best_diag_idx != k_invalid_index); + + if (true_second_best_diag_value != second_best_diagonal_item_value_) { + console_logger->debug("second_best_diagonal_item_value_ = {0}, true value = {1}", second_best_diagonal_item_value_, true_second_best_diag_value); + //console_logger->debug("{0}", *this); + } + + assert(true_second_best_diag_value == second_best_diagonal_item_value_); + } +#endif +} + + +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.h b/src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.h new file mode 100755 index 0000000..9192993 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.h @@ -0,0 +1,219 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_KDTREE_SINGLE_DIAG_H +#define AUCTION_ORACLE_KDTREE_SINGLE_DIAG_H + + +#include +#include +#include +#include +#include + +#include "basic_defs_ws.h" +#include "dnn/geometry/euclidean-fixed.h" +#include "dnn/local/kd-tree.h" + +namespace hera { +namespace ws { + + +template +struct ItemSlice; + +template +bool operator<(const ItemSlice& s_1, const ItemSlice& s_2); + +template +bool operator>(const ItemSlice& s_1, const ItemSlice& s_2); + +template +struct ItemSlice { +public: + using RealType = Real; + + size_t item_idx; + Real loss; + ItemSlice(size_t _item_idx, const Real _loss); + + void set_loss(const Real new_loss) { loss = new_loss; } + void adjust_loss(const Real delta) { loss -= delta; } + + friend bool operator< <>(const ItemSlice&, const ItemSlice&); + friend bool operator> <>(const ItemSlice&, const ItemSlice&); + +private: +}; + + +template +class LossesHeap { +public: + using ItemSliceR = ItemSlice; + using KeeperTypeR = std::set >; + using IterTypeR = typename KeeperTypeR::iterator; + + LossesHeap() {} + LossesHeap(const std::vector&); + void adjust_prices(const Real delta); // subtract delta from all values + ItemSliceR get_best_slice() const; + ItemSliceR get_second_best_slice() const; + + template + decltype(auto) emplace(Args&&... args) + { + return keeper.emplace(std::forward(args)...); + } + + + IterTypeR begin() { return keeper.begin(); } + IterTypeR end() { return keeper.end(); } + void erase(IterTypeR iter) { assert(iter != keeper.end()); keeper.erase(iter); } + decltype(auto) insert(const ItemSliceR& item) { return keeper.insert(item); } + size_t size() const { return keeper.size(); } + bool empty() const { return keeper.empty(); } +//private: + std::set > keeper; +}; + +template +struct DiagonalBid { + DiagonalBid() {} + + std::vector assigned_normal_items; + std::vector assigned_normal_items_bid_values; + + std::vector best_item_indices; + std::vector bid_values; + + // common bid value for diag-diag + Real diag_to_diag_value { 0.0 }; + + // analogous to second best item value; denoted by w in Bertsekas's paper on auction for transportation problem + Real almost_best_value { 0.0 }; + + // how many points to get from unassigned diagonal chunk + int num_from_unassigned_diag { 0 }; +}; + +template >> +struct AuctionOracleKDTreeSingleDiag : AuctionOracleBase { + + using PointContainer = PointContainer_; + using Real = Real_; + + using DnnPoint = dnn::Point<2, Real>; + using DnnTraits = dnn::PointTraits; + + using IdxValPairR = typename ws::IdxValPair; + using ItemSliceR = typename ws::ItemSlice; + using LossesHeapR = typename ws::LossesHeap; + using LossesHeapIterR = typename ws::LossesHeap::IterTypeR; + using DiagramPointR = typename ws::DiagramPoint; + using DiagonalBidR = typename ws::DiagonalBid; + + AuctionOracleKDTreeSingleDiag(const PointContainer& bidders, + const PointContainer& items, + const AuctionParams& params); + ~AuctionOracleKDTreeSingleDiag(); + // data members + // temporarily make everything public + Real max_val_; + size_t num_diag_items_; + size_t num_normal_items_; + size_t num_diag_bidders_; + size_t num_normal_bidders_; + dnn::KDTree* kdtree_; + std::vector dnn_points_; + std::vector dnn_point_handles_; + std::vector kdtree__items_; + + // this heap is used by off-diagonal bidders to get the cheapest diagonal + // item; index in the IdxVal is a valid item index in the vector of items + // items in diag_assigned_to_diag_slice_ and in diag_unassigned_slice_ + // are not stored in this heap + LossesHeapR diag_items_heap_; + // vector of iterators; if item_idx is in diag_assigned_to_diag_slice_ or + // in diag_unassigned_slice_, then diag_items_heap__iters_[item_idx] == + // diag_items_heap_.end() + std::vector diag_items_heap__iters_; + + + // this heap is used by _the_ diagonal bidder to get the cheapest items + // * value in IdxValPair is price + persistence (i.e., price for + // diagonal items) + // * index in IdxValPair is a valid item index in the vector of items + // items in diag_assigned_to_diag_slice_ and in diag_unassigned_slice_ + // are not stored in this heap + LossesHeapR all_items_heap_; + std::vector all_items_heap__iters_; + + std::unordered_set diag_assigned_to_diag_slice_; + std::unordered_set diag_unassigned_slice_; + + + std::unordered_set normal_items_assigned_to_diag_; + + Real diag_to_diag_price_; + Real diag_unassigned_price_; + + // methods + Real get_price(const size_t item_idx) const override; + void set_price(const size_t item_idx, + const Real new_price, + const bool item_is_diagonal, + const bool bidder_is_diagonal, + const OwnerType old_owner_type); + + IdxValPair get_optimal_bid(const IdxType bidder_idx); + + DiagonalBidR get_optimal_bids_for_diagonal(int unassigned_mass); + void process_unassigned_diagonal(const int unassigned_mass, + int& accumulated_mass, + bool& saw_diagonal_slice, + int& num_classes, + Real& w, + DiagonalBidR& result, + bool& found_w); + + void adjust_prices(); + void flush_assignment(); + void sanity_check(); + + bool is_item_diagonal(const size_t item_idx) const; + bool is_item_normal(const size_t item_idx) const { return not is_item_diagonal(item_idx); } + +}; + +} // ws +} // hera + +#include "auction_oracle_kdtree_single_diag.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.hpp b/src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.hpp new file mode 100755 index 0000000..42677ab --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_kdtree_single_diag.hpp @@ -0,0 +1,717 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ +#ifndef AUCTION_ORACLE_KDTREE_RESTRICTED_SINGLE_DIAG_HPP +#define AUCTION_ORACLE_KDTREE_RESTRICTED_SINGLE_DIAG_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "auction_oracle.h" + + +#ifdef FOR_R_TDA +#undef DEBUG_AUCTION +#endif + +namespace hera { +namespace ws { + +// ***************************** +// AuctionOracleKDTreeSingleDiag +// ***************************** + + + +template +ItemSlice::ItemSlice(size_t _item_idx, + const Real _loss) : + item_idx(_item_idx), + loss(_loss) +{ +} + + +template +bool operator<(const ItemSlice& s_1, const ItemSlice& s_2) +{ + return s_1.loss < s_2.loss + or (s_1.loss == s_2.loss and s_1.item_idx < s_2.item_idx); +} + +template +bool operator>(const ItemSlice& s_1, const ItemSlice& s_2) +{ + return s_1.loss > s_2.loss + or (s_1.loss == s_2.loss and s_1.item_idx > s_2.item_idx); +} + +template +std::ostream& operator<<(std::ostream& s, const ItemSlice& x) +{ + s << "(" << x.item_idx << ", " << x.loss << ")"; + return s; +} + +// ***************************** +// LossesHeap +// ***************************** + + +template +void LossesHeap::adjust_prices(const Real delta) +{ + throw std::runtime_error("not implemented"); +} + +template +typename LossesHeap::ItemSliceR LossesHeap::get_best_slice() const +{ + return *(keeper.begin()); +} + +template +typename LossesHeap::ItemSliceR LossesHeap::get_second_best_slice() const +{ + if (keeper.size() > 1) { + return *std::next(keeper.begin()); + } else { + return ItemSliceR(k_invalid_index, std::numeric_limits::max()); + } +} + +template +std::ostream& operator<<(std::ostream& s, const LossesHeap& x) +{ + s << "Heap[ "; + for(auto iter = x.keeper.begin(); iter != x.keeper.end(); ++iter) { + s << *iter << "\n"; + } + s << "]\n"; + return s; +} + +// ***************************** +// DiagonalBid +// ***************************** + +template +std::ostream& operator<<(std::ostream& s, const DiagonalBid& b) +{ + s << "DiagonalBid { num_from_unassigned_diag = " << b.num_from_unassigned_diag; + s << ", diag_to_diag_value = " << b.diag_to_diag_value; + s << ", almost_best_value = " << b.almost_best_value; + s << ",\nbest_item_indices = ["; + for(const auto i : b.best_item_indices) { + s << i << ", "; + } + s << "]\n"; + + s << ",\nbid_values= ["; + for(const auto v : b.bid_values) { + s << v << ", "; + } + s << "]\n"; + + s << ",\nassigned_normal_items= ["; + for(const auto i : b.assigned_normal_items) { + s << i << ", "; + } + s << "]\n"; + + s << ",\nassigned_normal_items_bid_values = ["; + for(const auto v : b.assigned_normal_items_bid_values) { + s << v << ", "; + } + s << "]\n"; + + return s; +} + +// ***************************** +// AuctionOracleKDTreeSingleDiag +// ***************************** + +template +std::ostream& operator<<(std::ostream& s, const AuctionOracleKDTreeSingleDiag& x) +{ + s << "oracle: bidders" << std::endl; + for(const auto& p : x.bidders) { + s << p << "\n"; + } + s << "items:"; + + for(const auto& p : x.items) { + s << p << "\n"; + } + + s << "diag_unassigned_slice_.size = " << x.diag_unassigned_slice_.size() << ", "; + s << "diag_unassigned_price_ = " << x.diag_unassigned_price_ << ", "; + s << "diag unassigned slice ["; + + for(const auto& i : x.diag_unassigned_slice_) { + s << i << ", "; + } + s << "]\n "; + + s << "diag_assigned_to_diag_slice_.size = " << x.diag_assigned_to_diag_slice_.size() << ", "; + s << "diag_assigned_to_diag_price = " << x.diag_to_diag_price_ << "\n"; + s << "diag_assigned_to_diag_slice_ ["; + + for(const auto& i : x.diag_assigned_to_diag_slice_) { + s << i << ", "; + } + s << "]\n "; + + s << "diag_items_heap_.size = " << x.diag_items_heap_.size() << "\n "; + s << x.diag_items_heap_; + + s << "all_items_heap_.size = " << x.all_items_heap_.size() << "\n "; + s << x.all_items_heap_; + + s << "epsilon = " << x.epsilon << std::endl; + + return s; +} + + +template +AuctionOracleKDTreeSingleDiag::AuctionOracleKDTreeSingleDiag(const PointContainer_& _bidders, + const PointContainer_& _items, + const AuctionParams& params) : + AuctionOracleBase(_bidders, _items, params), + max_val_(std::pow( 3.0 * getFurthestDistance3Approx<>(_bidders, _items, params.internal_p), params.wasserstein_power)), + num_diag_items_(0), + kdtree__items_(_items.size(), k_invalid_index) +{ + size_t dnn_item_idx { 0 }; + dnn_points_.clear(); + + all_items_heap__iters_.clear(); + all_items_heap__iters_.reserve( 4 * _items.size() / 7); + + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + const auto& item = this->items[item_idx]; + if (item.is_normal() ) { + // store normal items in kd-tree + kdtree__items_[item_idx] = dnn_item_idx; + // index of items is id of dnn-point + DnnPoint p(item_idx); + p[0] = item.x; + p[1] = item.y; + dnn_points_.push_back(p); + assert(dnn_item_idx == dnn_points_.size() - 1); + dnn_item_idx++; + // add slice to vector + auto ins_res = all_items_heap_.emplace(item_idx, this->get_value_for_diagonal_bidder(item_idx)); + all_items_heap__iters_.push_back(ins_res.first); + assert(ins_res.second); + } else { + // all diagonal items are initially in the unassigned slice + diag_unassigned_slice_.insert(item_idx); + all_items_heap__iters_.push_back(all_items_heap_.end()); + diag_items_heap__iters_.push_back(diag_items_heap_.end()); + ++num_diag_items_; + } + } + + num_normal_items_ = this->items.size() - num_diag_items_; + num_normal_bidders_ = num_diag_items_; + num_diag_bidders_ = this->bidders.size() - num_normal_bidders_; + + assert(dnn_points_.size() < _items.size() ); + for(size_t i = 0; i < dnn_points_.size(); ++i) { + dnn_point_handles_.push_back(&dnn_points_[i]); + } + + DnnTraits traits; + traits.internal_p = params.internal_p; + + kdtree_ = new dnn::KDTree(traits, dnn_point_handles_, this->wasserstein_power); + + sanity_check(); + + //std::cout << "IN CTOR: " << *this << std::endl; + +} + + +template +void AuctionOracleKDTreeSingleDiag::process_unassigned_diagonal(const int unassigned_mass, int& accumulated_mass, bool& saw_diagonal_slice, int& num_classes, Real& w, DiagonalBidR& result, bool& found_w) +{ + result.num_from_unassigned_diag = std::min(static_cast(diag_unassigned_slice_.size()), static_cast(unassigned_mass - accumulated_mass)); + if (not saw_diagonal_slice) { + saw_diagonal_slice = true; + ++num_classes; + } + + accumulated_mass += result.num_from_unassigned_diag; + //std::cout << "got mass from diagunassigned_slice, result.num_from_unassigned_diag = " << result.num_from_unassigned_diag << ", accumulated_mass = " << accumulated_mass << std::endl; + + if (static_cast(diag_unassigned_slice_.size()) > result.num_from_unassigned_diag and num_classes >= 2) { + found_w = true; + w = diag_unassigned_price_; + //std::cout << "w found from diag_unassigned_slice_, too, w = " << w << std::endl; + result.almost_best_value = w; + } + +} + + +template +typename AuctionOracleKDTreeSingleDiag::DiagonalBidR AuctionOracleKDTreeSingleDiag::get_optimal_bids_for_diagonal(int unassigned_mass) +{ + sanity_check(); + + assert(unassigned_mass == static_cast(num_diag_bidders_) + - static_cast(normal_items_assigned_to_diag_.size()) + - static_cast(diag_assigned_to_diag_slice_.size()) ); + assert(unassigned_mass > 0); + + DiagonalBidR result; + + + // number of similarity classes already assigned to diagonal bidder + // each normal point is a single class + // all diagonal points are in one class + int num_classes = normal_items_assigned_to_diag_.size() + ( diag_assigned_to_diag_slice_.empty() ? 0 : 1 ); + bool saw_diagonal_slice = not diag_assigned_to_diag_slice_.empty(); + bool found_w = false; + + //std::cout << "Enter get_optimal_bids_for_diagonal, unassigned_mass = " << unassigned_mass <<", num_classes = " << num_classes << ", saw_diagonal_slice = " << std::boolalpha << saw_diagonal_slice << std::endl; + + decltype(unassigned_mass) accumulated_mass = 0; + + + Real w { std::numeric_limits::max() }; + bool unassigned_not_processed = not diag_unassigned_slice_.empty(); + + for(auto slice_iter = all_items_heap_.begin(); slice_iter != all_items_heap_.end(); ++slice_iter) { + + auto slice = *slice_iter; + + if ( is_item_normal(slice.item_idx) and normal_items_assigned_to_diag_.count(slice.item_idx) == 1) { + //std::cout << __LINE__ << ": skipping slice " << slice << std::endl; + // this item is already assigned to diagonal bidder, skip + continue; + } + + if (unassigned_not_processed and slice.loss >= diag_unassigned_price_) { + // diag_unassigned slice is better, + // process it first + process_unassigned_diagonal(unassigned_mass, accumulated_mass, saw_diagonal_slice, num_classes, w, result, found_w); + unassigned_not_processed = false; + if (accumulated_mass >= unassigned_mass and found_w) { + break; + } + } + + + if (is_item_normal(slice.item_idx)) { + // all off-diagonal items are distinct + ++num_classes; + } else if (not saw_diagonal_slice) { + saw_diagonal_slice = true; + ++num_classes; + } + + if (accumulated_mass < unassigned_mass) { + //std::cout << __LINE__ << ": added slice to best items " << slice << std::endl; + result.best_item_indices.push_back(slice.item_idx); + } + + if (accumulated_mass >= unassigned_mass and num_classes >= 2) { + //std::cout << "Found w, slice = " << slice << std::endl; + w = slice.loss; + found_w = true; + result.almost_best_value = w; + break; + } + + // all items in all_items heap have mass 1 + ++accumulated_mass; + //std::cout << "accumulated_mass = " << accumulated_mass << std::endl; + + } + + if (unassigned_not_processed and (accumulated_mass < unassigned_mass or not found_w)) { + process_unassigned_diagonal(unassigned_mass, accumulated_mass, saw_diagonal_slice, num_classes, w, result, found_w); + } + + assert(found_w); + + //if (w == std::numeric_limits::max()) { std::cout << "HERE: " << *this << std::endl; } + assert(w != std::numeric_limits::max()); + + result.assigned_normal_items.clear(); + result.assigned_normal_items_bid_values.clear(); + + result.assigned_normal_items.reserve(normal_items_assigned_to_diag_.size()); + result.assigned_normal_items_bid_values.reserve(normal_items_assigned_to_diag_.size()); + + // add already assigned normal items and their new prices to bid + for(const auto item_idx : normal_items_assigned_to_diag_) { + assert( all_items_heap__iters_[item_idx] != all_items_heap_.end() ); + assert( is_item_normal(item_idx) ); + + result.assigned_normal_items.push_back(item_idx); + Real bid_value = w - this->get_cost_for_diagonal_bidder(item_idx) + this->epsilon; + //if ( bid_value <= this->get_price(item_idx) ) { + //std::cout << bid_value << " vs price " << this->get_price(item_idx) << std::endl; + //std::cout << *this << std::endl; + //} + assert( bid_value >= this->get_price(item_idx) ); + result.assigned_normal_items_bid_values.push_back(bid_value); + } + + // calculate bid values + // diag-to-diag items all have the same bid value + if (saw_diagonal_slice) { + result.diag_to_diag_value = w + this->epsilon; + } else { + result.diag_to_diag_value = std::numeric_limits::max(); + } + + result.bid_values.reserve(result.best_item_indices.size()); + for(const auto item_idx : result.best_item_indices) { + Real bid_value = w - this->get_cost_for_diagonal_bidder(item_idx) + this->epsilon; + result.bid_values.push_back(bid_value); + } + + return result; +} + + +template +IdxValPair AuctionOracleKDTreeSingleDiag::get_optimal_bid(IdxType bidder_idx) +{ + //std::cout << "enter get_optimal_bid" << std::endl; + sanity_check(); + + auto bidder = this->bidders[bidder_idx]; + + size_t best_item_idx; + Real best_item_price; + Real best_item_value; + Real second_best_item_value; + + // this function is for normal bidders only + assert(bidder.is_normal()); + + + // get 2 best items among non-diagonal points from kdtree_ + DnnPoint bidder_dnn; + bidder_dnn[0] = bidder.getRealX(); + bidder_dnn[1] = bidder.getRealY(); + auto two_best_items = kdtree_->findK(bidder_dnn, 2); + size_t best_normal_item_idx { two_best_items[0].p->id() }; + Real best_normal_item_value { two_best_items[0].d }; + // if there is only one off-diagonal point in the second diagram, + // kd-tree will not return the second candidate. + // Set its price to inf, so it will always lose to the price of the projection + Real second_best_normal_item_value { two_best_items.size() == 1 ? std::numeric_limits::max() : two_best_items[1].d }; + + size_t best_diag_item_idx; + Real best_diag_value; + Real best_diag_price; + + { + Real diag_edge_cost = std::pow(bidder.persistence_lp(this->internal_p), this->wasserstein_power); + auto best_diag_price_in_heap = diag_items_heap_.empty() ? std::numeric_limits::max() : diag_items_heap_.get_best_slice().loss; + auto best_diag_idx_in_heap = diag_items_heap_.empty() ? k_invalid_index : diag_items_heap_.get_best_slice().item_idx; + // if unassigned_diag_slice is empty, its price is max, + // same for diag-diag assigned slice, so the ifs below will work + + if (best_diag_price_in_heap <= diag_to_diag_price_ and best_diag_price_in_heap <= diag_unassigned_price_) { + best_diag_item_idx = best_diag_idx_in_heap; + best_diag_value = diag_edge_cost + best_diag_price_in_heap; + best_diag_price = best_diag_price_in_heap; + } else if (diag_to_diag_price_ < best_diag_price_in_heap and diag_to_diag_price_ < diag_unassigned_price_) { + best_diag_item_idx = *diag_assigned_to_diag_slice_.begin(); + best_diag_value = diag_edge_cost + diag_to_diag_price_; + best_diag_price = diag_to_diag_price_; + } else { + best_diag_item_idx = *diag_unassigned_slice_.begin(); + best_diag_value = diag_edge_cost + diag_unassigned_price_; + best_diag_price = diag_unassigned_price_; + } + + } + + if ( best_diag_value < best_normal_item_value) { + best_item_idx = best_diag_item_idx; + best_item_price = best_diag_price; + best_item_value = best_diag_value; + second_best_item_value = best_normal_item_value; + } else if (best_diag_value < second_best_normal_item_value) { + best_item_idx = best_normal_item_idx; + best_item_price = this->get_price(best_item_idx); + best_item_value = best_normal_item_value; + second_best_item_value = best_diag_value; + } else { + best_item_idx = best_normal_item_idx; + best_item_price = this->get_price(best_item_idx); + best_item_value = best_normal_item_value; + second_best_item_value = second_best_normal_item_value; + } + + IdxValPair result; + + result.first = best_item_idx; + result.second = ( second_best_item_value - best_item_value ) + best_item_price + this->epsilon; + + //std::cout << "bidder_idx = " << bidder_idx << ", best_item_idx = " << best_item_idx << ", best_item_value = " << best_item_value << ", second_best_item_value = " << second_best_item_value << ", eps = " << this->epsilon << std::endl; + assert( second_best_item_value >= best_item_value ); + //assert( best_item_price == this->get_price(best_item_idx) ); + assert(result.second >= best_item_price); + sanity_check(); + + return result; +} +/* +a_{ij} = d_{ij} +price_{ij} = a_{ij} + price_j +*/ + + + +//template +//std::vector AuctionOracleKDTreeSingleDiag::increase_price_of_assigned_to_diag(WHAT) +//{ + //WHAT; +//} +// + +template +Real_ AuctionOracleKDTreeSingleDiag::get_price(const size_t item_idx) const +{ + if (is_item_diagonal(item_idx)) { + if (diag_assigned_to_diag_slice_.count(item_idx) == 1) { + return diag_to_diag_price_; + } else if (diag_unassigned_slice_.count(item_idx) == 1) { + return diag_unassigned_price_; + } + } + return this-> prices[item_idx]; +} + +template +void AuctionOracleKDTreeSingleDiag::set_price(const size_t item_idx, + const Real new_price, + const bool item_is_diagonal, + const bool bidder_is_diagonal, + const OwnerType old_owner_type) +{ + + //std::cout << std::boolalpha << "enter set_price, item_idx = " << item_idx << ", new_price = " << new_price << ", old price = " << this->get_price(item_idx); + //std::cout << ", item_is_diagonal = " << item_is_diagonal << ", bidder_is_diagonal = " << bidder_is_diagonal << ", old_owner_type = " << old_owner_type << std::endl; + + bool item_is_normal = not item_is_diagonal; + bool bidder_is_normal = not bidder_is_diagonal; + + assert( new_price >= this->get_price(item_idx) ); + + // update vector prices + if (item_is_normal or bidder_is_normal) { + this->prices[item_idx] = new_price; + } + + // update kdtree_ + if (item_is_normal) { + assert(0 <= item_idx and item_idx < kdtree__items_.size()); + assert(0 <= kdtree__items_[item_idx] and kdtree__items_[item_idx] < dnn_point_handles_.size()); + kdtree_->change_weight( dnn_point_handles_[kdtree__items_[item_idx]], new_price); + } + + // update all_items_heap_ + if (bidder_is_diagonal and item_is_diagonal) { + // remove slice (item is buried in diag_assigned_to_diag_slice_) + assert(old_owner_type != OwnerType::k_diagonal); + auto iter = all_items_heap__iters_[item_idx]; + assert(iter != all_items_heap_.end()); + all_items_heap_.erase(iter); + all_items_heap__iters_[item_idx] = all_items_heap_.end(); + } else { + auto iter = all_items_heap__iters_[item_idx]; + if (iter != all_items_heap_.end()) { + // update existing element + ItemSliceR x = *iter; + x.set_loss( this->get_value_for_diagonal_bidder(item_idx) ); + all_items_heap_.erase(iter); + auto ins_res = all_items_heap_.insert(x); + all_items_heap__iters_[item_idx] = ins_res.first; + assert(ins_res.second); + } else { + // insert new slice + // for diagonal items value = price + ItemSliceR x { item_idx, new_price }; + auto ins_res = all_items_heap_.insert(x); + all_items_heap__iters_[item_idx] = ins_res.first; + assert(ins_res.second); + } + } + + // update diag_items_heap_ + if (item_is_diagonal and bidder_is_normal) { + // update existing element + auto iter = diag_items_heap__iters_[item_idx]; + if (iter != diag_items_heap_.end()) { + ItemSliceR x = *iter; + x.set_loss( new_price ); + diag_items_heap_.erase(iter); + auto ins_res = diag_items_heap_.insert(x); + diag_items_heap__iters_[item_idx] = ins_res.first; + assert(ins_res.second); + } else { + // insert new slice + // for diagonal items value = price + ItemSliceR x { item_idx, new_price }; + auto ins_res = diag_items_heap_.insert(x); + diag_items_heap__iters_[item_idx] = ins_res.first; + assert(ins_res.second); + } + } else if (bidder_is_diagonal and item_is_diagonal ) { + // remove slice (item is buried in diag_assigned_to_diag_slice_) + assert(old_owner_type != OwnerType::k_diagonal); + auto iter = diag_items_heap__iters_[item_idx]; + assert(iter != diag_items_heap_.end()); + diag_items_heap_.erase(iter); + diag_items_heap__iters_[item_idx] = diag_items_heap_.end(); + } + + // update diag_unassigned_price_ + if (item_is_diagonal and old_owner_type == OwnerType::k_none and diag_unassigned_slice_.empty()) { + diag_unassigned_price_ = std::numeric_limits::max(); + } + +} + + +template +bool AuctionOracleKDTreeSingleDiag::is_item_diagonal(const size_t item_idx) const +{ + return item_idx < this->num_diag_items_; +} + + +template +void AuctionOracleKDTreeSingleDiag::flush_assignment() +{ + //std::cout << "enter oracle->flush_assignment" << std::endl; + sanity_check(); + + for(const auto item_idx : diag_assigned_to_diag_slice_) { + diag_unassigned_slice_.insert(item_idx); + } + diag_assigned_to_diag_slice_.clear(); + + // common price of diag-diag items becomes price of diag-unassigned-slice + // diag_to_diag_slice is now empty, set its price to max + // so that get_optimal_bid works correctly + diag_unassigned_price_ = diag_to_diag_price_; + diag_to_diag_price_ = std::numeric_limits::max(); + + normal_items_assigned_to_diag_.clear(); + + sanity_check(); +} + + +template +void AuctionOracleKDTreeSingleDiag::adjust_prices() +{ + return; + + throw std::runtime_error("not implemented"); + auto pr_begin = this->prices.begin(); + auto pr_end = this->prices.end(); + + Real min_price = *(std::min_element(pr_begin, pr_end)); + + for(auto& p : this->prices) { + p -= min_price; + } + + kdtree_->adjust_weights(min_price); + diag_items_heap_.adjust_prices(min_price); + all_items_heap_.adjust_prices(min_price); +} + + +template +AuctionOracleKDTreeSingleDiag::~AuctionOracleKDTreeSingleDiag() +{ + delete kdtree_; +} + +template +void AuctionOracleKDTreeSingleDiag::sanity_check() +{ +#ifdef DEBUG_AUCTION + + //std::cout << "ORACLE CURRENT STATE IN SANITY CHECK" << *this << std::endl; + + assert( diag_items_heap_.size() + diag_assigned_to_diag_slice_.size() + diag_unassigned_slice_.size() == num_diag_items_ ); + assert( diag_items_heap__iters_.size() == num_diag_items_ ); + for(size_t i = 0; i < num_diag_items_; ++i) { + if (diag_items_heap__iters_.at(i) != diag_items_heap_.end()) { + assert(diag_items_heap__iters_[i]->item_idx == i); + } + } + + assert( all_items_heap_.size() + diag_assigned_to_diag_slice_.size() + diag_unassigned_slice_.size() == this->num_items_ ); + assert( all_items_heap__iters_.size() == this->num_items_ ); + for(size_t i = 0; i < this->num_items_; ++i) { + if (all_items_heap__iters_.at(i) != all_items_heap_.end()) { + assert(all_items_heap__iters_[i]->item_idx == i); + } else { + assert( i < num_diag_items_ ); + } + } + + for(size_t i = 0; i < num_diag_items_; ++i) { + int is_in_assigned_slice = diag_assigned_to_diag_slice_.count(i); + int is_in_unassigned_slice = diag_unassigned_slice_.count(i); + int is_in_heap = diag_items_heap__iters_[i] != diag_items_heap_.end(); + assert( is_in_assigned_slice + is_in_unassigned_slice + is_in_heap == 1); + } + + //assert((diag_assigned_to_diag_slice_.empty() and diag_to_diag_price_ == std::numeric_limits::max()) or (not diag_assigned_to_diag_slice_.empty() and diag_to_diag_price_ != std::numeric_limits::max())); + //assert((diag_unassigned_slice_.empty() and diag_unassigned_price_ == std::numeric_limits::max()) or (not diag_unassigned_slice_.empty() and diag_unassigned_price_ != std::numeric_limits::max())); + + assert(diag_assigned_to_diag_slice_.empty() or diag_to_diag_price_ != std::numeric_limits::max()); + assert(diag_unassigned_slice_.empty() or diag_unassigned_price_ != std::numeric_limits::max()); +#endif +} + + +} // ws +} // hera +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_lazy_heap.h b/src/dionysus/wasserstein/auction_oracle_lazy_heap.h new file mode 100755 index 0000000..8b37421 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_lazy_heap.h @@ -0,0 +1,191 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_LAZY_HEAP_H +#define AUCTION_ORACLE_LAZY_HEAP_H + + +#define USE_BOOST_HEAP + +#include +#include +#include +#include + +#ifdef USE_BOOST_HEAP +#include +#endif + +#include "basic_defs_ws.h" + +namespace ws { + +template +struct CompPairsBySecondStruct { + bool operator()(const IdxValPair& a, const IdxValPair& b) const + { + return a.second < b.second; + } +}; + + +template +struct CompPairsBySecondGreaterStruct { + bool operator()(const IdxValPair& a, const IdxValPair& b) const + { + return a.second > b.second; + } +}; + +template +struct CompPairsBySecondLexStruct { + bool operator()(const IdxValPair& a, const IdxValPair& b) const + { + return a.second < b.second or (a.second == b.second and a.first > b.first); + } +}; + +template +struct CompPairsBySecondLexGreaterStruct { + bool operator()(const IdxValPair& a, const IdxValPair& b) const + { + return a.second > b.second or (a.second == b.second and a.first > b.first); + } +}; + +using ItemsTimePair = std::pair; +using UpdateList = std::list; +using UpdateListIter = UpdateList::iterator; + + +#ifdef USE_BOOST_HEAP +template +using LossesHeap = boost::heap::d_ary_heap, boost::heap::arity<2>, boost::heap::mutable_, boost::heap::compare>>; +#else +template +class IdxValHeap { +public: + using InternalKeeper = std::set, ComparisonStruct>; + using handle_type = typename InternalKeeper::iterator; + // methods + handle_type push(const IdxValPair& val) + { + auto res_pair = _heap.insert(val); + assert(res_pair.second); + assert(res_pair.first != _heap.end()); + return res_pair.first; + } + + void decrease(handle_type& handle, const IdxValPair& new_val) + { + _heap.erase(handle); + handle = push(new_val); + } + + void increase(handle_type& handle, const IdxValPair& new_val) + { + _heap.erase(handle); + handle = push(new_val); + + size_t size() const + { + return _heap.size(); + } + + handle_type ordered_begin() + { + return _heap.begin(); + } + + handle_type ordered_end() + { + return _heap.end(); + } + + +private: + std::set, ComparisonStruct> _heap; +}; + +// if we store losses, the minimal value should come first +template +using LossesHeap = IdxValHeap; +#endif + + +template +struct AuctionOracleLazyHeapRestricted : AuctionOracleBase { + + using LossesHeapR = typename ws::LossesHeap; + using LossesHeapRHandle = typename ws::LossesHeap::handle_type; + using DiagramPointR = typename ws::DiagramPoint; + + + AuctionOracleLazyHeapRestricted(const std::vector& bidders, const std::vector& items, const Real wasserstein_power, const Real _internal_p = get_infinity()); + ~AuctionOracleLazyHeapRestricted(); + // data members + // temporarily make everything public + std::vector> weight_matrix; + //Real weight_adj_const; + Real max_val; + // vector of heaps to find the best items + std::vector losses_heap; + std::vector> items_indices_for_heap_handles; + std::vector> losses_heap_handles; + // methods + void fill_in_losses_heap(); + void set_price(const IdxType items_idx, const Real new_price); + IdxValPair get_optimal_bid(const IdxType bidder_idx); + Real get_matching_weight(const std::vector& bidders_to_items) const; + void adjust_prices(); + // to update the queue in lazy fashion + std::vector items_iterators; + UpdateList update_list; + std::vector bidders_update_moments; + int update_counter; + void update_queue_for_bidder(const IdxType bidder_idx); + LossesHeapR diag_items_heap; + std::vector diag_heap_handles; + std::vector heap_handles_indices; + // debug + + DebugOptimalBid get_optimal_bid_debug(const IdxType bidder_idx); + + // for diagonal points + bool best_diagonal_items_computed; + size_t best_diagonal_item_idx; + Real best_diagonal_item_value; + size_t second_best_diagonal_item_idx; + Real second_best_diagonal_item_value; +}; + +} // end of namespace ws + +#include "auction_oracle_lazy_heap.h" + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_lazy_heap.hpp b/src/dionysus/wasserstein/auction_oracle_lazy_heap.hpp new file mode 100755 index 0000000..d179b3d --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_lazy_heap.hpp @@ -0,0 +1,465 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "auction_oracle.h" + + +#ifdef FOR_R_TDA +#undef DEBUG_AUCTION +#endif + +namespace ws { + +// ***************************** +// AuctionOracleLazyHeapRestricted +// ***************************** + + +template +AuctionOracleLazyHeapRestricted::AuctionOracleLazyHeapRestricted(const std::vector>& _bidders, + const std::vector>& _items, + Real _wasserstein_power, + Real _internal_p) : + AuctionOracleAbstract(_bidders, _items, _wasserstein_power, _internal_p), + max_val(0.0), + bidders_update_moments(_bidders.size(), 0), + update_counter(0), + heap_handles_indices(_items.size(), k_invalid_index), + best_diagonal_items_computed(false) +{ + weight_matrix.reserve(_bidders.size()); + //const Real max_dist_upper_bound = 3 * getFurthestDistance3Approx(b, g); + //weight_adj_const = pow(max_dist_upper_bound, wasserstein_power); + // init weight matrix + for(const auto& point_A : _bidders) { + std::vector weight_vec; + weight_vec.clear(); + weight_vec.reserve(_bidders.size()); + for(const auto& point_B : _items) { + Real val = pow(dist_lp(point_A, point_B, _internal_p), _wasserstein_power); + weight_vec.push_back( val ); + if ( val > max_val ) + max_val = val; + } + weight_matrix.push_back(weight_vec); + } + fill_in_losses_heap(); + for(size_t item_idx = 0; item_idx < _items.size(); ++item_idx) { + update_list.push_back(std::make_pair(static_cast(item_idx), 0)); + } + for(auto update_list_iter = update_list.begin(); update_list_iter != update_list.end(); ++update_list_iter) { + items_iterators.push_back(update_list_iter); + } + + size_t handle_idx {0}; + for(size_t item_idx = 0; item_idx < _items.size(); ++item_idx) { + if (_items[item_idx].is_diagonal() ) { + heap_handles_indices[item_idx] = handle_idx++; + diag_heap_handles.push_back(diag_items_heap.push(std::make_pair(item_idx, 0))); + } + } +} + + +template +void AuctionOracleLazyHeapRestricted::update_queue_for_bidder(IdxType bidder_idx) +{ + assert(0 <= bidder_idx and bidder_idx < static_cast(this->bidders.size())); + assert(bidder_idx < static_cast(bidders_update_moments.size())); + assert(losses_heap[bidder_idx] != nullptr ); + + int bidder_last_update_time = bidders_update_moments[bidder_idx]; + auto iter = update_list.begin(); + while (iter != update_list.end() and iter->second >= bidder_last_update_time) { + IdxType item_idx = iter->first; + size_t handle_idx = items_indices_for_heap_handles[bidder_idx][item_idx]; + if (handle_idx < this->items.size() ) { + IdxValPair new_val { item_idx, weight_matrix[bidder_idx][item_idx] + this->prices[item_idx]}; + // to-do: change indexing of losses_heap_handles + losses_heap[bidder_idx]->decrease(losses_heap_handles[bidder_idx][handle_idx], new_val); + } + iter++; + } + bidders_update_moments[bidder_idx] = update_counter; +} + + +template +void AuctionOracleLazyHeapRestricted::fill_in_losses_heap() +{ + using LossesHeapR = typename ws::LossesHeap; + using LossesHeapRHandleVec = typename std::vector::handle_type>; + + for(size_t bidder_idx = 0; bidder_idx < this->bidders.size(); ++bidder_idx) { + DiagramPoint bidder { this->bidders[bidder_idx] }; + // no heap for diagonal bidders + if ( bidder.is_diagonal() ) { + losses_heap.push_back( nullptr ); + losses_heap_handles.push_back(LossesHeapRHandleVec()); + items_indices_for_heap_handles.push_back( std::vector() ); + continue; + } else { + losses_heap.push_back( new LossesHeapR() ); + + assert( losses_heap.at(bidder_idx) != nullptr ); + + items_indices_for_heap_handles.push_back( std::vector(this->items.size(), k_invalid_index) ); + LossesHeapRHandleVec handles_vec; + losses_heap_handles.push_back(handles_vec); + size_t handle_idx { 0 }; + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + assert( items_indices_for_heap_handles.at(bidder_idx).at(item_idx) > 0 ); + DiagramPoint item { this->items[item_idx] }; + if ( item.is_normal() ) { + // item can be assigned to bidder, store in heap + IdxValPair vp { item_idx, weight_matrix[bidder_idx][item_idx] + this->prices[item_idx] }; + losses_heap_handles[bidder_idx].push_back( losses_heap[bidder_idx]->push(vp) ); + // keep corresponding index in items_indices_for_heap_handles + items_indices_for_heap_handles[bidder_idx][item_idx] = handle_idx++; + } + } + } + } +} + + +template +AuctionOracleLazyHeapRestricted::~AuctionOracleLazyHeapRestricted() +{ + for(auto h : losses_heap) { + delete h; + } +} + + +template +void AuctionOracleLazyHeapRestricted::set_price(IdxType item_idx, Real new_price) +{ + assert( this->prices.at(item_idx) < new_price ); +#ifdef DEBUG_AUCTION + std::cout << "price incremented by " << this->prices.at(item_idx) - new_price << std::endl; +#endif + this->prices[item_idx] = new_price; + if (this->items[item_idx].is_normal() ) { + // lazy: record the moment we updated the price of the items, + // do not update queues. + // 1. move the items with updated price to the front of the update_list, + update_list.splice(update_list.begin(), update_list, items_iterators[item_idx]); + // 2. record the moment we updated the price and increase the counter + update_list.front().second = update_counter++; + } else { + // diagonal items are stored in one heap + diag_items_heap.decrease(diag_heap_handles[heap_handles_indices[item_idx]], std::make_pair(item_idx, new_price)); + best_diagonal_items_computed = false; + } +} + +// subtract min. price from all prices +template +void AuctionOracleLazyHeapRestricted::adjust_prices() +{ +} + + +template +DebugOptimalBid AuctionOracleLazyHeapRestricted::get_optimal_bid_debug(IdxType bidder_idx) +{ + DebugOptimalBid result; + assert(bidder_idx >=0 and bidder_idx < static_cast(this->bidders.size()) ); + + auto bidder = this->bidders[bidder_idx]; + std::vector> cand_items; + // corresponding point is always considered as a candidate + + size_t proj_item_idx = bidder_idx; + assert( 0 <= proj_item_idx and proj_item_idx < this->items.size() ); + auto proj_item = this->items[proj_item_idx]; + assert(proj_item.type != bidder.type); + //assert(proj_item.proj_id == bidder.id); + //assert(proj_item.id == bidder.proj_id); + // todo: store precomputed distance? + Real proj_item_value = this->get_value_for_bidder(bidder_idx, proj_item_idx); + cand_items.push_back( std::make_pair(proj_item_idx, proj_item_value) ); + + if (bidder.is_normal()) { + assert(losses_heap.at(bidder_idx) != nullptr); + assert(losses_heap[bidder_idx]->size() >= 2); + update_queue_for_bidder(bidder_idx); + auto pHeap = losses_heap[bidder_idx]; + assert( pHeap != nullptr ); + auto top_iter = pHeap->ordered_begin(); + cand_items.push_back( *top_iter ); + ++top_iter; // now points to the second-best items + cand_items.push_back( *top_iter ); + std::sort(cand_items.begin(), cand_items.end(), CompPairsBySecondStruct()); + assert(cand_items[1].second >= cand_items[0].second); + } else { + // for diagonal bidder the only normal point has already been added + // the other 2 candidates are diagonal items only, get from the heap + // with prices + assert(diag_items_heap.size() > 1); + auto top_diag_iter = diag_items_heap.ordered_begin(); + auto topDiag1 = *top_diag_iter++; + auto topDiag2 = *top_diag_iter; + cand_items.push_back(topDiag1); + cand_items.push_back(topDiag2); + std::sort(cand_items.begin(), cand_items.end(), CompPairsBySecondStruct()); + assert(cand_items.size() == 3); + assert(cand_items[2].second >= cand_items[1].second); + assert(cand_items[1].second >= cand_items[0].second); + } + + result.best_item_idx = cand_items[0].first; + result.second_best_item_idx = cand_items[1].first; + result.best_item_value = cand_items[0].second; + result.second_best_item_value = cand_items[1].second; + + // checking code + + //DebugOptimalBid debug_my_result(result); + //DebugOptimalBid debug_naive_result; + //debug_naive_result.best_item_value = 1e20; + //debug_naive_result.second_best_item_value = 1e20; + //Real curr_item_value; + //for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + //if ( this->bidders[bidder_idx].type != this->items[item_idx].type and + //this->bidders[bidder_idx].proj_id != this->items[item_idx].id) + //continue; + + //curr_item_value = pow(dist_lp(this->bidders[bidder_idx], this->items[item_idx]), wasserstein_power) + this->prices[item_idx]; + //if (curr_item_value < debug_naive_result.best_item_value) { + //debug_naive_result.best_item_value = curr_item_value; + //debug_naive_result.best_item_idx = item_idx; + //} + //} + + //for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + //if (item_idx == debug_naive_result.best_item_idx) { + //continue; + //} + //if ( this->bidders[bidder_idx].type != this->items[item_idx].type and + //this->bidders[bidder_idx].proj_id != this->items[item_idx].id) + //continue; + + //curr_item_value = pow(dist_lp(this->bidders[bidder_idx], this->items[item_idx]), wasserstein_power) + this->prices[item_idx]; + //if (curr_item_value < debug_naive_result.second_best_item_value) { + //debug_naive_result.second_best_item_value = curr_item_value; + //debug_naive_result.second_best_item_idx = item_idx; + //} + //} + + //if ( fabs( debug_my_result.best_item_value - debug_naive_result.best_item_value ) > 1e-6 or + //fabs( debug_naive_result.second_best_item_value - debug_my_result.second_best_item_value) > 1e-6 ) { + //std::cerr << "bidder_idx = " << bidder_idx << "; "; + //std::cerr << this->bidders[bidder_idx] << std::endl; + //for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + //std::cout << item_idx << ": " << this->items[item_idx] << "; price = " << this->prices[item_idx] << std::endl; + //} + //std::cerr << "debug_my_result: " << debug_my_result << std::endl; + //std::cerr << "debug_naive_result: " << debug_naive_result << std::endl; + //auto pHeap = losses_heap[bidder_idx]; + //assert( pHeap != nullptr ); + //for(auto top_iter = pHeap->ordered_begin(); top_iter != pHeap->ordered_end(); ++top_iter) { + //std::cerr << "in heap: " << top_iter->first << ": " << top_iter->second << "; real value = " << dist_lp(bidder, this->items[top_iter->first]) + this->prices[top_iter->first] << std::endl; + //} + //for(auto ci : cand_items) { + //std::cout << "ci.idx = " << ci.first << ", value = " << ci.second << std::endl; + //} + + ////std::cerr << "two_best_items: " << two_best_items[0].d << " " << two_best_items[1].d << std::endl; + //assert(false); + //} + + + //std::cout << "get_optimal_bid: bidder_idx = " << bidder_idx << "; best_item_idx = " << best_item_idx << "; best_item_value = " << best_item_value << "; best_items_price = " << this->prices[best_item_idx] << "; second_best_item_idx = " << top_iter->first << "; second_best_value = " << second_best_item_value << "; second_best_price = " << this->prices[top_iter->first] << "; bid = " << this->prices[best_item_idx] + ( best_item_value - second_best_item_value ) + epsilon << "; epsilon = " << epsilon << std::endl; + //std::cout << "get_optimal_bid: bidder_idx = " << bidder_idx << "; best_item_idx = " << best_item_idx << "; best_items_dist= " << (weight_adj_const - best_item_value) << "; best_items_price = " << this->prices[best_item_idx] << "; second_best_item_idx = " << top_iter->first << "; second_best_dist= " << (weight_adj_const - second_best_item_value) << "; second_best_price = " << this->prices[top_iter->first] << "; bid = " << this->prices[best_item_idx] + ( best_item_value - second_best_item_value ) + epsilon << "; epsilon = " << epsilon << std::endl; + + return result; +} + + +template +IdxValPair AuctionOracleLazyHeapRestricted::get_optimal_bid(const IdxType bidder_idx) +{ + IdxType best_item_idx; + //IdxType second_best_item_idx; + Real best_item_value; + Real second_best_item_value; + + auto& bidder = this->bidders[bidder_idx]; + IdxType proj_item_idx = bidder_idx; + assert( 0 <= proj_item_idx and proj_item_idx < this->items.size() ); + auto proj_item = this->items[proj_item_idx]; + assert(proj_item.type != bidder.type); + //assert(proj_item.proj_id == bidder.id); + //assert(proj_item.id == bidder.proj_id); + // todo: store precomputed distance? + Real proj_item_value = this->get_value_for_bidder(bidder_idx, proj_item_idx); + + if (bidder.is_diagonal()) { + // for diagonal bidder the only normal point has already been added + // the other 2 candidates are diagonal items only, get from the heap + // with prices + assert(diag_items_heap.size() > 1); + if (!best_diagonal_items_computed) { + auto top_diag_iter = diag_items_heap.ordered_begin(); + best_diagonal_item_idx = top_diag_iter->first; + best_diagonal_item_value = top_diag_iter->second; + top_diag_iter++; + second_best_diagonal_item_idx = top_diag_iter->first; + second_best_diagonal_item_value = top_diag_iter->second; + best_diagonal_items_computed = true; + } + + if ( proj_item_value < best_diagonal_item_value) { + best_item_idx = proj_item_idx; + best_item_value = proj_item_value; + second_best_item_value = best_diagonal_item_value; + //second_best_item_idx = best_diagonal_item_idx; + } else if (proj_item_value < second_best_diagonal_item_value) { + best_item_idx = best_diagonal_item_idx; + best_item_value = best_diagonal_item_value; + second_best_item_value = proj_item_value; + //second_best_item_idx = proj_item_idx; + } else { + best_item_idx = best_diagonal_item_idx; + best_item_value = best_diagonal_item_value; + second_best_item_value = second_best_diagonal_item_value; + //second_best_item_idx = second_best_diagonal_item_idx; + } + } else { + // for normal bidder get 2 best items among non-diagonal (=normal) points + // from the corresponding heap + assert(diag_items_heap.size() > 1); + update_queue_for_bidder(bidder_idx); + auto top_norm_iter = losses_heap[bidder_idx]->ordered_begin(); + IdxType best_normal_item_idx { top_norm_iter->first }; + Real best_normal_item_value { top_norm_iter->second }; + top_norm_iter++; + Real second_best_normal_item_value { top_norm_iter->second }; + //IdxType second_best_normal_item_idx { top_norm_iter->first }; + + if ( proj_item_value < best_normal_item_value) { + best_item_idx = proj_item_idx; + best_item_value = proj_item_value; + second_best_item_value = best_normal_item_value; + //second_best_item_idx = best_normal_item_idx; + } else if (proj_item_value < second_best_normal_item_value) { + best_item_idx = best_normal_item_idx; + best_item_value = best_normal_item_value; + second_best_item_value = proj_item_value; + //second_best_item_idx = proj_item_idx; + } else { + best_item_idx = best_normal_item_idx; + best_item_value = best_normal_item_value; + second_best_item_value = second_best_normal_item_value; + //second_best_item_idx = second_best_normal_item_idx; + } + } + + IdxValPair result; + + assert( second_best_item_value >= best_item_value ); + + result.first = best_item_idx; + result.second = ( second_best_item_value - best_item_value ) + this->prices[best_item_idx] + this->epsilon; + + + // checking code + + //DebugOptimalBid debug_my_result; + //debug_my_result.best_item_idx = best_item_idx; + //debug_my_result.best_item_value = best_item_value; + //debug_my_result.second_best_item_idx = second_best_item_idx; + //debug_my_result.second_best_item_value = second_best_item_value; + //DebugOptimalBid debug_naive_result; + //debug_naive_result.best_item_value = 1e20; + //debug_naive_result.second_best_item_value = 1e20; + //Real curr_item_value; + //for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + //if ( this->bidders[bidder_idx].type != this->items[item_idx].type and + //this->bidders[bidder_idx].proj_id != this->items[item_idx].id) + //continue; + + //curr_item_value = this->get_value_for_bidder(bidder_idx, item_idx); + //if (curr_item_value < debug_naive_result.best_item_value) { + //debug_naive_result.best_item_value = curr_item_value; + //debug_naive_result.best_item_idx = item_idx; + //} + //} + + //for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + //if (item_idx == debug_naive_result.best_item_idx) { + //continue; + //} + //if ( this->bidders[bidder_idx].type != this->items[item_idx].type and + //this->bidders[bidder_idx].proj_id != this->items[item_idx].id) + //continue; + + //curr_item_value = this->get_value_for_bidder(bidder_idx, item_idx); + //if (curr_item_value < debug_naive_result.second_best_item_value) { + //debug_naive_result.second_best_item_value = curr_item_value; + //debug_naive_result.second_best_item_idx = item_idx; + //} + //} + ////std::cout << "got naive result" << std::endl; + + //if ( fabs( debug_my_result.best_item_value - debug_naive_result.best_item_value ) > 1e-6 or + //fabs( debug_naive_result.second_best_item_value - debug_my_result.second_best_item_value) > 1e-6 ) { + //std::cerr << "bidder_idx = " << bidder_idx << "; "; + //std::cerr << this->bidders[bidder_idx] << std::endl; + //for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + //std::cout << item_idx << ": " << this->items[item_idx] << "; price = " << this->prices[item_idx] << std::endl; + //} + //std::cerr << "debug_my_result: " << debug_my_result << std::endl; + //std::cerr << "debug_naive_result: " << debug_naive_result << std::endl; + //auto pHeap = losses_heap[bidder_idx]; + //if ( pHeap != nullptr ) { + //for(auto top_iter = pHeap->ordered_begin(); top_iter != pHeap->ordered_end(); ++top_iter) { + //std::cerr << "in heap: " << top_iter->first << ": " << top_iter->second << "; real value = " << dist_lp(bidder, this->items[top_iter->first]) + this->prices[top_iter->first] << std::endl; + //} + //} + ////for(auto ci : cand_items) { + ////std::cout << "ci.idx = " << ci.first << ", value = " << ci.second << std::endl; + ////} + + ////std::cerr << "two_best_items: " << two_best_items[0].d << " " << two_best_items[1].d << std::endl; + //assert(false); + // } + //std::cout << "get_optimal_bid: bidder_idx = " << bidder_idx << "; best_item_idx = " << best_item_idx << "; best_item_value = " << best_item_value << "; best_items_price = " << this->prices[best_item_idx] << "; second_best_item_idx = " << top_iter->first << "; second_best_value = " << second_best_item_value << "; second_best_price = " << this->prices[top_iter->first] << "; bid = " << this->prices[best_item_idx] + ( best_item_value - second_best_item_value ) + epsilon << "; epsilon = " << epsilon << std::endl; + //std::cout << "get_optimal_bid: bidder_idx = " << bidder_idx << "; best_item_idx = " << best_item_idx << "; best_items_dist= " << (weight_adj_const - best_item_value) << "; best_items_price = " << this->prices[best_item_idx] << "; second_best_item_idx = " << top_iter->first << "; second_best_dist= " << (weight_adj_const - second_best_item_value) << "; second_best_price = " << this->prices[top_iter->first] << "; bid = " << this->prices[best_item_idx] + ( best_item_value - second_best_item_value ) + epsilon << "; epsilon = " << epsilon << std::endl; + + return result; +} + +} // end of namespace ws diff --git a/src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.h b/src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.h new file mode 100755 index 0000000..c932396 --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.h @@ -0,0 +1,114 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_ORACLE_STUPID_SPARSE_RESTRICTED_H +#define AUCTION_ORACLE_STUPID_SPARSE_RESTRICTED_H + +#include +#include +#include + +#include "basic_defs_ws.h" +#include "diagonal_heap.h" +#include "auction_oracle_base.h" +#include "dnn/geometry/euclidean-fixed.h" +#include "dnn/local/kd-tree.h" + + +namespace hera { +namespace ws { + +template >> +struct AuctionOracleStupidSparseRestricted : AuctionOracleBase { + + using PointContainer = PointContainer_; + using Real = Real_; + + using LossesHeapR = typename ws::LossesHeapOld; + using LossesHeapRHandle = typename ws::LossesHeapOld::handle_type; + using DiagramPointR = typename ws::DiagramPoint; + using DebugOptimalBidR = typename ws::DebugOptimalBid; + + using DnnPoint = dnn::Point<2, Real>; + using DnnTraits = dnn::PointTraits; + + + AuctionOracleStupidSparseRestricted(const PointContainer& bidders, const PointContainer& items, const AuctionParams& params); + // data members + // temporarily make everything public + std::vector> admissible_items_; + Real max_val_; + LossesHeapR diag_items_heap_; + std::vector diag_heap_handles_; + std::vector heap_handles_indices_; + +// std::vector kdtree_items_; + + std::vector top_diag_indices_; + std::vector top_diag_lookup_; + size_t top_diag_counter_ { 0 }; + bool best_diagonal_items_computed_ { false }; + Real best_diagonal_item_value_; + Real second_best_diagonal_item_idx_ { k_invalid_index }; + Real second_best_diagonal_item_value_ { std::numeric_limits::max() }; + + + // methods + void set_price(const IdxType items_idx, const Real new_price, const bool update_diag = true); + IdxValPair get_optimal_bid(const IdxType bidder_idx); + void adjust_prices(); + void adjust_prices(const Real delta); + + // debug routines + DebugOptimalBidR get_optimal_bid_debug(IdxType bidder_idx) const; + void sanity_check(); + + + // heap top vector + size_t get_heap_top_size() const; + void recompute_top_diag_items(bool hard = false); + void recompute_second_best_diag(); + void reset_top_diag_counter(); + void increment_top_diag_counter(); + void add_top_diag_index(const size_t item_idx); + void remove_top_diag_index(const size_t item_idx); + bool is_in_top_diag_indices(const size_t item_idx) const; + + std::shared_ptr console_logger; + + std::pair get_minmax_price() const; + +}; + +} // ws +} // hera + + +#include "auction_oracle_stupid_sparse_restricted.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.hpp b/src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.hpp new file mode 100755 index 0000000..8f4504d --- /dev/null +++ b/src/dionysus/wasserstein/auction_oracle_stupid_sparse_restricted.hpp @@ -0,0 +1,568 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ +#ifndef AUCTION_ORACLE_STUPID_SPARSE_HPP +#define AUCTION_ORACLE_STUPID_SPARSE_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "basic_defs_ws.h" +#include "auction_oracle_stupid_sparse_restricted.h" + +#ifdef FOR_R_TDA +#undef DEBUG_AUCTION +#endif + +namespace hera { +namespace ws { + + +// ***************************** +// AuctionOracleStupidSparseRestricted +// ***************************** + + +template +std::ostream& operator<<(std::ostream& output, const AuctionOracleStupidSparseRestricted& oracle) +{ + output << "Oracle " << &oracle << std::endl; + output << fmt::format(" max_val_ = {0}, best_diagonal_items_computed_ = {1}, best_diagonal_item_value_ = {2}, second_best_diagonal_item_idx_ = {3}, second_best_diagonal_item_value_ = {4}\n", + oracle.max_val_, + oracle.best_diagonal_items_computed_, + oracle.best_diagonal_item_value_, + oracle.second_best_diagonal_item_idx_, + oracle.second_best_diagonal_item_value_); + + output << fmt::format(" prices = {0}\n", + format_container_to_log(oracle.prices)); + + output << fmt::format(" diag_items_heap_ = {0}\n", + losses_heap_to_string(oracle.diag_items_heap_)); + + + output << fmt::format(" top_diag_indices_ = {0}\n", + format_container_to_log(oracle.top_diag_indices_)); + + output << fmt::format(" top_diag_counter_ = {0}\n", + oracle.top_diag_counter_); + + output << fmt::format(" top_diag_lookup_ = {0}\n", + format_container_to_log(oracle.top_diag_lookup_)); + + + output << "end of oracle " << &oracle << std::endl; + return output; +} + + +template +AuctionOracleStupidSparseRestricted::AuctionOracleStupidSparseRestricted(const PointContainer_& _bidders, + const PointContainer_& _items, + const AuctionParams& params) : + AuctionOracleBase(_bidders, _items, params), + admissible_items_(_bidders.size(), std::vector()), + heap_handles_indices_(_items.size(), k_invalid_index), + top_diag_lookup_(_items.size(), k_invalid_index) +{ + // initialize admissible edges + std::vector kdtree_items_(_items.size(), k_invalid_index); + std::vector dnn_points_; + std::vector dnn_point_handles_; + size_t dnn_item_idx { 0 }; + size_t true_idx { 0 }; + dnn_points_.reserve(this->items.size()); + // store normal items in kd-tree + for(const auto& g : this->items) { + if (g.is_normal() ) { + kdtree_items_[true_idx] = dnn_item_idx; + // index of items is id of dnn-point + DnnPoint p(true_idx); + p[0] = g.getRealX(); + p[1] = g.getRealY(); + dnn_points_.push_back(p); + assert(dnn_item_idx == dnn_points_.size() - 1); + dnn_item_idx++; + } + true_idx++; + } + assert(dnn_points_.size() < _items.size() ); + for(size_t i = 0; i < dnn_points_.size(); ++i) { + dnn_point_handles_.push_back(&dnn_points_[i]); + } + DnnTraits traits; + traits.internal_p = params.internal_p; + dnn::KDTree kdtree_(traits, dnn_point_handles_, params.wasserstein_power); + + // loop over normal bidders, find nearest neighbours + size_t bidder_idx = 0; + for(const auto& b : this->bidders) { + if (b.is_normal()) { + admissible_items_[bidder_idx].reserve(k_max_nn); + DnnPoint bidder_dnn; + bidder_dnn[0] = b.getRealX(); + bidder_dnn[1] = b.getRealY(); + auto nearest_neighbours = kdtree_.findK(bidder_dnn, k_max_nn); + assert(nearest_neighbours.size() == k_max_nn); + for(const auto& x : nearest_neighbours) { + admissible_items_[bidder_idx].push_back(x.p->id()); + } + } + bidder_idx++; + } + + size_t handle_idx {0}; + for(size_t item_idx = 0; item_idx < _items.size(); ++item_idx) { + if (this->items[item_idx].is_diagonal()) { + heap_handles_indices_[item_idx] = handle_idx++; + diag_heap_handles_.push_back(diag_items_heap_.push(std::make_pair(item_idx, 0.0))); + } + } + max_val_ = 3*getFurthestDistance3Approx<>(_bidders, _items, params.internal_p); + max_val_ = std::pow(max_val_, params.wasserstein_power); + + console_logger = spdlog::get("console"); + if (not console_logger) { + console_logger = spdlog::stdout_logger_st("console"); + } + console_logger->set_pattern("[%H:%M:%S.%e] %v"); + console_logger->info("Stupid sparse oracle ctor done, k = {0}", k_max_nn); +} + + +template +bool AuctionOracleStupidSparseRestricted::is_in_top_diag_indices(const size_t item_idx) const +{ + return top_diag_lookup_[item_idx] != k_invalid_index; +} + + +template +void AuctionOracleStupidSparseRestricted::add_top_diag_index(const size_t item_idx) +{ + assert(find(top_diag_indices_.begin(), top_diag_indices_.end(), item_idx) == top_diag_indices_.end()); + assert(this->items[item_idx].is_diagonal()); + + top_diag_indices_.push_back(item_idx); + top_diag_lookup_[item_idx] = top_diag_indices_.size() - 1; +} + +template +void AuctionOracleStupidSparseRestricted::remove_top_diag_index(const size_t item_idx) +{ + if (top_diag_indices_.size() > 1) { + // remove item_idx from top_diag_indices after swapping + // it with the last element, update index lookup appropriately + auto old_index = top_diag_lookup_[item_idx]; + auto end_element = top_diag_indices_.back(); + std::swap(top_diag_indices_[old_index], top_diag_indices_.back()); + top_diag_lookup_[end_element] = old_index; + } + + top_diag_indices_.pop_back(); + top_diag_lookup_[item_idx] = k_invalid_index; + if (top_diag_indices_.size() < 2) { + recompute_second_best_diag(); + } + best_diagonal_items_computed_ = not top_diag_indices_.empty(); + reset_top_diag_counter(); +} + + +template +void AuctionOracleStupidSparseRestricted::increment_top_diag_counter() +{ + assert(top_diag_counter_ >= 0 and top_diag_counter_ < top_diag_indices_.size()); + + ++top_diag_counter_; + if (top_diag_counter_ >= top_diag_indices_.size()) { + top_diag_counter_ -= top_diag_indices_.size(); + } + + assert(top_diag_counter_ >= 0 and top_diag_counter_ < top_diag_indices_.size()); +} + + +template +void AuctionOracleStupidSparseRestricted::reset_top_diag_counter() +{ + top_diag_counter_ = 0; +} + +template +void AuctionOracleStupidSparseRestricted::recompute_top_diag_items(bool hard) +{ + console_logger->debug("Enter recompute_top_diag_items, hard = {0}", hard); + assert(hard or top_diag_indices_.empty()); + + if (hard) { + std::fill(top_diag_lookup_.begin(), top_diag_lookup_.end(), k_invalid_index); + top_diag_indices_.clear(); + } + + auto top_diag_iter = diag_items_heap_.ordered_begin(); + best_diagonal_item_value_ = top_diag_iter->second; + add_top_diag_index(top_diag_iter->first); + + ++top_diag_iter; + + // traverse the heap while we see the same value + while(top_diag_iter != diag_items_heap_.ordered_end()) { + if ( top_diag_iter->second != best_diagonal_item_value_) { + break; + } else { + add_top_diag_index(top_diag_iter->first); + } + ++top_diag_iter; + } + + recompute_second_best_diag(); + + best_diagonal_items_computed_ = true; + reset_top_diag_counter(); + console_logger->debug("Exit recompute_top_diag_items, hard = {0}", hard); +} + +template +typename AuctionOracleStupidSparseRestricted::DebugOptimalBidR +AuctionOracleStupidSparseRestricted::get_optimal_bid_debug(IdxType bidder_idx) const +{ + DebugOptimalBidR result; + throw std::runtime_error("Not implemented"); + return result; +} + + +template +IdxValPair AuctionOracleStupidSparseRestricted::get_optimal_bid(IdxType bidder_idx) +{ + auto bidder = this->bidders[bidder_idx]; + + // corresponding point is always considered as a candidate + // if bidder is a diagonal point, proj_item is a normal point, + // and vice versa. + + size_t best_item_idx { k_invalid_index }; + size_t second_best_item_idx { k_invalid_index }; + size_t best_diagonal_item_idx { k_invalid_index }; + Real best_item_value; + Real second_best_item_value; + + + size_t proj_item_idx = bidder_idx; + assert( 0 <= proj_item_idx and proj_item_idx < this->items.size() ); + assert(this->items[proj_item_idx].type != bidder.type); + Real proj_item_value = this->get_value_for_bidder(bidder_idx, proj_item_idx); + + if (bidder.is_diagonal()) { + // for diagonal bidder the only normal point has already been added + // the other 2 candidates are diagonal items only, get from the heap + // with prices + + if (not best_diagonal_items_computed_) { + recompute_top_diag_items(); + } + + best_diagonal_item_idx = top_diag_indices_[top_diag_counter_]; + increment_top_diag_counter(); + + if ( proj_item_value < best_diagonal_item_value_) { + best_item_idx = proj_item_idx; + best_item_value = proj_item_value; + second_best_item_value = best_diagonal_item_value_; + second_best_item_idx = best_diagonal_item_idx; + } else if (proj_item_value < second_best_diagonal_item_value_) { + best_item_idx = best_diagonal_item_idx; + best_item_value = best_diagonal_item_value_; + second_best_item_value = proj_item_value; + second_best_item_idx = proj_item_idx; + } else { + best_item_idx = best_diagonal_item_idx; + best_item_value = best_diagonal_item_value_; + second_best_item_value = second_best_diagonal_item_value_; + second_best_item_idx = second_best_diagonal_item_idx_; + } + } else { + + size_t best_normal_item_idx { k_invalid_index }; + size_t second_best_normal_item_idx { k_invalid_index }; + Real best_normal_item_value { std::numeric_limits::max() }; + Real second_best_normal_item_value { std::numeric_limits::max() }; + + // find best item + for(const auto curr_item_idx : admissible_items_[bidder_idx]) { + auto curr_item_value = this->get_value_for_bidder(bidder_idx, curr_item_idx); + if (curr_item_value < best_normal_item_value) { + best_normal_item_idx = curr_item_idx; + best_normal_item_value = curr_item_value; + } + } + + // find second-best item + for(const auto curr_item_idx : admissible_items_[bidder_idx]) { + if (curr_item_idx == best_normal_item_idx) { + continue; + } + auto curr_item_value = this->get_value_for_bidder(bidder_idx, curr_item_idx); + if (curr_item_value < second_best_normal_item_value) { + second_best_normal_item_idx = curr_item_idx; + second_best_normal_item_value = curr_item_value; + } + } + + if ( proj_item_value < best_normal_item_value) { + best_item_idx = proj_item_idx; + increment_top_diag_counter(); + best_item_value = proj_item_value; + second_best_item_value = best_normal_item_value; + } else if (proj_item_value < second_best_normal_item_value) { + best_item_idx = best_normal_item_idx; + best_item_value = best_normal_item_value; + second_best_item_value = proj_item_value; + } else { + best_item_idx = best_normal_item_idx; + best_item_value = best_normal_item_value; + second_best_item_value = second_best_normal_item_value; + } + } + + IdxValPair result; + + assert( second_best_item_value >= best_item_value ); + + result.first = best_item_idx; + result.second = ( second_best_item_value - best_item_value ) + this->prices[best_item_idx] + this->epsilon; + + return result; +} + +template +void AuctionOracleStupidSparseRestricted::recompute_second_best_diag() +{ + + console_logger->debug("Enter recompute_second_best_diag"); + + if (top_diag_indices_.size() > 1) { + second_best_diagonal_item_value_ = best_diagonal_item_value_; + second_best_diagonal_item_idx_ = top_diag_indices_[0]; + } else { + if (diag_items_heap_.size() == 1) { + second_best_diagonal_item_value_ == std::numeric_limits::max(); + second_best_diagonal_item_idx_ = k_invalid_index; + } else { + auto diag_iter = diag_items_heap_.ordered_begin(); + ++diag_iter; + second_best_diagonal_item_value_ = diag_iter->second; + second_best_diagonal_item_idx_ = diag_iter->first; + } + } + + console_logger->debug("Exit recompute_second_best_diag, second_best_diagonal_item_value_ = {0}, second_best_diagonal_item_idx_ = {1}", second_best_diagonal_item_value_, second_best_diagonal_item_idx_); +} + + +template +void AuctionOracleStupidSparseRestricted::set_price(IdxType item_idx, + Real new_price, + const bool update_diag) +{ + + console_logger->debug("Enter set_price, item_idx = {0}, new_price = {1}, old price = {2}, update_diag = {3}", item_idx, new_price, this->prices[item_idx], update_diag); + + assert(this->prices.size() == this->items.size()); + assert( 0 < diag_heap_handles_.size() and diag_heap_handles_.size() <= this->items.size()); + // adjust_prices decreases prices, + // also this variable must be true in reverse phases of FR-auction + bool item_goes_down = new_price > this->prices[item_idx]; + + this->prices[item_idx] = new_price; + if ( this->items[item_idx].is_diagonal() ) { + assert(diag_heap_handles_.size() > heap_handles_indices_.at(item_idx)); + if (item_goes_down) { + diag_items_heap_.decrease(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } else { + diag_items_heap_.increase(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } + if (update_diag) { + // Update top_diag_indices_ only if necessary: + // normal bidders take their projections, which might not be on top + // also, set_price is called by adjust_prices, and we may have already + // removed the item from top_diag + if (is_in_top_diag_indices(item_idx)) { + remove_top_diag_index(item_idx); + } + + if (item_idx == second_best_diagonal_item_idx_) { + recompute_second_best_diag(); + } + } + } + + console_logger->debug("Exit set_price, item_idx = {0}, new_price = {1}", item_idx, new_price); +} + + +template +void AuctionOracleStupidSparseRestricted::adjust_prices(Real delta) +{ + //console_logger->debug("Enter adjust_prices, delta = {0}", delta); + //std::cerr << *this << std::endl; + + if (delta == 0.0) + return; + + for(auto& p : this->prices) { + p -= delta; + } + + bool price_goes_up = delta < 0; + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + if (this->items[item_idx].is_diagonal()) { + auto new_price = this->prices[item_idx]; + if (price_goes_up) { + diag_items_heap_.decrease(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } else { + diag_items_heap_.increase(diag_heap_handles_[heap_handles_indices_[item_idx]], std::make_pair(item_idx, new_price)); + } + } + } + best_diagonal_item_value_ -= delta; + second_best_diagonal_item_value_ -= delta; + + //std::cerr << *this << std::endl; + //console_logger->debug("Exit adjust_prices, delta = {0}", delta); +} + +template +void AuctionOracleStupidSparseRestricted::adjust_prices() +{ + auto pr_begin = this->prices.begin(); + auto pr_end = this->prices.end(); + Real min_price = *(std::min_element(pr_begin, pr_end)); + adjust_prices(min_price); +} + +template +size_t AuctionOracleStupidSparseRestricted::get_heap_top_size() const +{ + return top_diag_indices_.size(); +} + +template +std::pair AuctionOracleStupidSparseRestricted::get_minmax_price() const +{ + auto r = std::minmax_element(this->prices.begin(), this->prices.end()); + return std::make_pair(*r.first, *r.second); +} + +template +void AuctionOracleStupidSparseRestricted::sanity_check() +{ +#ifdef DEBUG_STUPID_SPARSE_RESTR_ORACLE + + assert(admissible_items_.size() == this->bidders.size()); + + for(size_t bidder_idx = 0; bidder_idx < this->bidders.size(); ++bidder_idx) { + if (this->bidders[bidder_idx].is_normal()) { + assert(admissible_items_[bidder_idx].size() == k_max_nn); + } else { + assert(admissible_items_[bidder_idx].size() == 0); + } + } + + if (best_diagonal_items_computed_) { + std::vector diag_items_price_vec; + diag_items_price_vec.reserve(this->items.size()); + + for(size_t item_idx = 0; item_idx < this->items.size(); ++item_idx) { + if (this->items.at(item_idx).is_diagonal()) { + diag_items_price_vec.push_back(this->prices.at(item_idx)); + } else { + diag_items_price_vec.push_back(std::numeric_limits::max()); + } + } + + auto best_iter = std::min_element(diag_items_price_vec.begin(), diag_items_price_vec.end()); + assert(best_iter != diag_items_price_vec.end()); + Real true_best_diag_value = *best_iter; + size_t true_best_diag_idx = best_iter - diag_items_price_vec.begin(); + assert(true_best_diag_value != std::numeric_limits::max()); + + Real true_second_best_diag_value = std::numeric_limits::max(); + size_t true_second_best_diag_idx = k_invalid_index; + for(size_t item_idx = 0; item_idx < diag_items_price_vec.size(); ++item_idx) { + if (this->items.at(item_idx).is_normal()) { + assert(top_diag_lookup_.at(item_idx) == k_invalid_index); + continue; + } + + auto i_iter = std::find(top_diag_indices_.begin(), top_diag_indices_.end(), item_idx); + if (diag_items_price_vec.at(item_idx) == true_best_diag_value) { + assert(i_iter != top_diag_indices_.end()); + assert(top_diag_lookup_.at(item_idx) == i_iter - top_diag_indices_.begin()); + } else { + assert(top_diag_lookup_.at(item_idx) == k_invalid_index); + assert(i_iter == top_diag_indices_.end()); + } + + if (item_idx == true_best_diag_idx) { + continue; + } + if (diag_items_price_vec.at(item_idx) < true_second_best_diag_value) { + true_second_best_diag_value = diag_items_price_vec.at(item_idx); + true_second_best_diag_idx = item_idx; + } + } + + if (true_best_diag_value != best_diagonal_item_value_) { + console_logger->debug("best_diagonal_item_value_ = {0}, true value = {1}", best_diagonal_item_value_, true_best_diag_value); + std::cerr << *this; + //console_logger->debug("{0}", *this); + } + + assert(true_best_diag_value == best_diagonal_item_value_); + + assert(true_second_best_diag_idx != k_invalid_index); + + if (true_second_best_diag_value != second_best_diagonal_item_value_) { + console_logger->debug("second_best_diagonal_item_value_ = {0}, true value = {1}", second_best_diagonal_item_value_, true_second_best_diag_value); + //console_logger->debug("{0}", *this); + } + + assert(true_second_best_diag_value == second_best_diagonal_item_value_); + } +#endif +} + + +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/auction_runner_fr.h b/src/dionysus/wasserstein/auction_runner_fr.h new file mode 100755 index 0000000..1abca20 --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_fr.h @@ -0,0 +1,289 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_RUNNER_FR_H +#define AUCTION_RUNNER_FR_H + +#define ORDERED_BY_PERSISTENCE + +#include + +#include "auction_oracle.h" + +namespace hera { +namespace ws { + +// the two parameters that you can tweak in auction algorithm are: +// 1. epsilon_common_ratio +// 2. max_num_phases + +template> // alternatively: AuctionOracleLazyHeap --- TODO +class AuctionRunnerFR { +public: + + using Real = RealType_; + using AuctionOracle = AuctionOracle_; + using DgmPoint = DiagramPoint; + using DgmPointVec = std::vector; + using IdxValPairR = IdxValPair; + + const Real k_lowest_bid_value = -(std::numeric_limits::max() - 1); // all bid values must be positive + + + AuctionRunnerFR(const std::vector& A, + const std::vector& B, + const Real q, + const Real _delta, + const Real _internal_p, + const Real _initial_epsilon = 0.0, + const Real _eps_factor = 5.0, + const int _max_num_phases = std::numeric_limits::max(), + const Real _gamma_threshold = 0.0, + const size_t _max_bids_per_round = std::numeric_limits::max(), + const std::string& _log_filename_prefix = ""); + + void set_epsilon(Real new_val); + Real get_epsilon() const { return epsilon; } + void run_auction(); + Real get_wasserstein_distance(); + Real get_wasserstein_cost(); + Real get_relative_error(const bool debug_output = false) const; + bool phase_can_be_final() const; +private: + // private data + std::vector bidders, items; + const size_t num_bidders; + const size_t num_items; + std::vector items_to_bidders; + std::vector bidders_to_items; + Real wasserstein_power; + Real epsilon; + Real delta; + Real internal_p; + Real initial_epsilon; + const Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio + Real cumulative_epsilon_factor { 1.0 }; + const int max_num_phases; // maximal number of phases of epsilon-scaling + bool is_forward { true }; // to use in distributed version only + Real weight_adj_const; + Real wasserstein_cost; + std::vector forward_bid_table; + std::vector reverse_bid_table; + // to get the 2 best items + AuctionOracle forward_oracle; + AuctionOracle reverse_oracle; + std::unordered_set unassigned_bidders; + std::unordered_set unassigned_items; + std::unordered_set items_with_bids; + std::unordered_set bidders_with_bids; + +#ifdef ORDERED_BY_PERSISTENCE + // to process unassigned by persistence + size_t batch_size; + using RealIdxPair = std::pair; + std::set> unassigned_bidders_by_persistence; + std::set> unassigned_items_by_persistence; +#endif + + + // to imitate Gauss-Seidel + const size_t max_bids_per_round; + + // to stop earlier in the last phase + const Real total_items_persistence; + const Real total_bidders_persistence; + Real partial_cost; + Real unassigned_bidders_persistence; + Real unassigned_items_persistence; + Real gamma_threshold; + size_t unassigned_threshold; + + bool is_distance_computed { false }; + int num_rounds { 0 }; + int num_rounds_non_cumulative { 0 }; + int num_phase { 0 }; + + + size_t num_diag_items { 0 }; + size_t num_normal_items { 0 }; + size_t num_diag_bidders { 0 }; + size_t num_normal_bidders { 0 }; + + + + // private methods + void assign_forward(const IdxType item_idx, const IdxType bidder_idx); + void assign_reverse(const IdxType item_idx, const IdxType bidder_idx); + void assign_to_best_bidder(const IdxType item_idx); + void assign_to_best_item(const IdxType bidder_idx); + void clear_forward_bid_table(); + void clear_reverse_bid_table(); + void assign_diag_to_diag(); + void run_auction_phases(const int max_num_phases, const Real _initial_epsilon); + void run_auction_phase(); + void run_forward_auction_phase(); + void run_reverse_auction_phase(); + void submit_forward_bid(IdxType bidder_idx, const IdxValPairR& bid); + void submit_reverse_bid(IdxType item_idx, const IdxValPairR& bid); + void flush_assignment(); + Real get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx) const; + Real get_cost_to_diagonal(const DgmPoint& pt) const; + Real get_gamma() const; + + template + void run_forward_bidding_step(const Range& r); + + template + void run_reverse_bidding_step(const Range& r); + + void add_unassigned_bidder(const size_t bidder_idx); + void add_unassigned_item(const size_t item_idx); + void remove_unassigned_bidder(const size_t bidder_idx); + void remove_unassigned_item(const size_t item_idx); + + bool is_item_diagonal(const size_t item_idx) const { return item_idx < num_diag_items; } + bool is_item_normal(const size_t item_idx) const { return not is_item_diagonal(item_idx); } + bool is_bidder_diagonal(const size_t bidder_idx) const { return bidder_idx >= num_normal_bidders; } + bool is_bidder_normal(const size_t bidder_idx) const { return not is_bidder_diagonal(bidder_idx); } + + size_t num_forward_bids_submitted { 0 }; + size_t num_reverse_bids_submitted { 0 }; + + void decrease_epsilon(); + // stopping criteria + bool continue_forward(const size_t, const size_t); + bool continue_reverse(const size_t, const size_t); + bool continue_phase(); + + + + // for debug only + void sanity_check(); + void check_epsilon_css(); + void print_debug(); + void print_matching(); + + std::string log_filename_prefix; + const Real k_max_relative_error = 2.0; // if relative error cannot be estimated or is too large, use this value + void reset_round_stat(); // empty, if logging is disable + void reset_phase_stat(); + + std::unordered_set never_assigned_bidders; + std::unordered_set never_assigned_items; + + std::shared_ptr console_logger; +#ifdef LOG_AUCTION + std::unordered_set unassigned_normal_bidders; + std::unordered_set unassigned_diag_bidders; + + std::unordered_set unassigned_normal_items; + std::unordered_set unassigned_diag_items; + + + size_t all_assigned_round { 0 }; + size_t all_assigned_round_found { false }; + + int num_forward_rounds_non_cumulative { 0 }; + int num_forward_rounds { 0 }; + + int num_reverse_rounds_non_cumulative { 0 }; + int num_reverse_rounds { 0 }; + + // all per-round vars are non-cumulative + + // forward rounds + int num_normal_forward_bids_submitted { 0 }; + int num_diag_forward_bids_submitted { 0 }; + + int num_forward_diag_to_diag_assignments { 0 }; + int num_forward_diag_to_normal_assignments { 0 }; + int num_forward_normal_to_diag_assignments { 0 }; + int num_forward_normal_to_normal_assignments { 0 }; + + int num_forward_diag_from_diag_thefts { 0 }; + int num_forward_diag_from_normal_thefts { 0 }; + int num_forward_normal_from_diag_thefts { 0 }; + int num_forward_normal_from_normal_thefts { 0 }; + + // reverse rounds + int num_normal_reverse_bids_submitted { 0 }; + int num_diag_reverse_bids_submitted { 0 }; + + int num_reverse_diag_to_diag_assignments { 0 }; + int num_reverse_diag_to_normal_assignments { 0 }; + int num_reverse_normal_to_diag_assignments { 0 }; + int num_reverse_normal_to_normal_assignments { 0 }; + + int num_reverse_diag_from_diag_thefts { 0 }; + int num_reverse_diag_from_normal_thefts { 0 }; + int num_reverse_normal_from_diag_thefts { 0 }; + int num_reverse_normal_from_normal_thefts { 0 }; + + // price change statistics + std::vector> forward_price_change_cnt_vec; + std::vector> reverse_price_change_cnt_vec; + + const char* forward_plot_logger_name = "forward_plot_logger"; + const char* reverse_plot_logger_name = "reverse_plot_logger"; + const char* forward_price_stat_logger_name = "forward_price_stat_logger"; + const char* reverse_price_stat_logger_name = "reverse_price_stat_logger"; + + std::string forward_plot_logger_file_name; + std::string reverse_plot_logger_file_name; + std::string forward_price_stat_logger_file_name; + std::string reverse_price_stat_logger_file_name; + + std::shared_ptr forward_plot_logger; + std::shared_ptr reverse_plot_logger; + std::shared_ptr forward_price_stat_logger; + std::shared_ptr reverse_price_stat_logger; + + + size_t parallel_threshold = 5000; + int num_parallelizable_rounds { 0 }; + int num_parallelizable_forward_rounds { 0 }; + int num_parallelizable_reverse_rounds { 0 }; + + int num_parallel_bids { 0 }; + int num_total_bids { 0 }; + + int num_parallel_assignments { 0 }; + int num_total_assignments { 0 }; +#endif + +}; + + + +} // ws +} // hera + +#include "auction_runner_fr.hpp" + +#undef ORDERED_BY_PERSISTENCE +#endif diff --git a/src/dionysus/wasserstein/auction_runner_fr.hpp b/src/dionysus/wasserstein/auction_runner_fr.hpp new file mode 100755 index 0000000..07c1459 --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_fr.hpp @@ -0,0 +1,1440 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_RUNNER_FR_HPP +#define AUCTION_RUNNER_FR_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" + +#include "auction_runner_fr.h" + +#ifdef FOR_R_TDA +#include "Rcpp.h" +#undef DEBUG_FR_AUCTION +#endif + + +namespace hera { +namespace ws { + + +// ***************************** +// AuctionRunnerFR +// ***************************** + +template +AuctionRunnerFR::AuctionRunnerFR(const std::vector& A, + const std::vector& B, + const Real q, + const Real _delta, + const Real _internal_p, + const Real _initial_epsilon, + const Real _eps_factor, + const int _max_num_phases, + const Real _gamma_threshold, + const size_t _max_bids_per_round, + const std::string& _log_filename_prefix + ) : + bidders(A), + items(B), + num_bidders(A.size()), + num_items(A.size()), + items_to_bidders(A.size(), k_invalid_index), + bidders_to_items(A.size(), k_invalid_index), + wasserstein_power(q), + delta(_delta), + internal_p(_internal_p), + initial_epsilon(_initial_epsilon), + epsilon_common_ratio(_eps_factor == 0.0 ? 5.0 : _eps_factor), + max_num_phases(_max_num_phases), + forward_bid_table(A.size(), std::make_pair(k_invalid_index, k_lowest_bid_value) ), + reverse_bid_table(B.size(), std::make_pair(k_invalid_index, k_lowest_bid_value) ), + forward_oracle(bidders, items, q, _internal_p), + reverse_oracle(items, bidders, q, _internal_p), + max_bids_per_round(_max_bids_per_round), + total_items_persistence(std::accumulate(items.begin(), + items.end(), + R(0.0), + [_internal_p, q](const Real& ps, const DgmPoint& item) + { return ps + std::pow(item.persistence_lp(_internal_p), q); } + )), + total_bidders_persistence(std::accumulate(bidders.begin(), + bidders.end(), + R(0.0), + [_internal_p, q](const Real& ps, const DgmPoint& bidder) + { return ps + std::pow(bidder.persistence_lp(_internal_p), q); } + )), + partial_cost(0.0), + unassigned_bidders_persistence(total_bidders_persistence), + unassigned_items_persistence(total_items_persistence), + gamma_threshold(_gamma_threshold), + log_filename_prefix(_log_filename_prefix) +{ + assert(A.size() == B.size()); + for(const auto& p : bidders) { + if (p.is_normal()) { + num_normal_bidders++; + num_diag_items++; + } else { + num_normal_items++; + num_diag_bidders++; + } + } + +#ifdef ORDERED_BY_PERSISTENCE + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + unassigned_bidders_by_persistence.insert(std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)); + } + + for(size_t item_idx = 0; item_idx < num_items; ++item_idx) { + unassigned_items_by_persistence.insert(std::make_pair(items[item_idx].persistence_lp(1.0), item_idx)); + } +#endif + + // for experiments + unassigned_threshold = bidders.size() / 200; + unassigned_threshold = 0; + +#ifdef ORDERED_BY_PERSISTENCE + batch_size = 5000; +#endif + + console_logger = spdlog::get("console"); + if (not console_logger) { + console_logger = spdlog::stdout_logger_st("console"); + } + console_logger->set_pattern("[%H:%M:%S.%e] %v"); + console_logger->info("Forward-reverse runnder, max_num_phases = {0}, max_bids_per_round = {1}, gamma_threshold = {2}, unassigned_threshold = {3}", + max_num_phases, + max_bids_per_round, + gamma_threshold, + unassigned_threshold); + + +// check_epsilon_css(); +#ifdef LOG_AUCTION + parallel_threshold = bidders.size() / 100; + forward_plot_logger_file_name = log_filename_prefix + "_forward_plot.txt"; + forward_plot_logger = spdlog::get(forward_plot_logger_name); + if (not forward_plot_logger) { + forward_plot_logger = spdlog::basic_logger_st(forward_plot_logger_name, forward_plot_logger_file_name); + } + forward_plot_logger->info("New forward plot starts here, diagram size = {0}, gamma_threshold = {1}, epsilon_common_ratio = {2}", + bidders.size(), + gamma_threshold, + epsilon_common_ratio); + forward_plot_logger->set_pattern("%v"); + + reverse_plot_logger_file_name = log_filename_prefix + "_reverse_plot.txt"; + reverse_plot_logger = spdlog::get(reverse_plot_logger_name); + if (not reverse_plot_logger) { + reverse_plot_logger = spdlog::basic_logger_st(reverse_plot_logger_name, reverse_plot_logger_file_name); + } + reverse_plot_logger->info("New reverse plot starts here, diagram size = {0}, gamma_threshold = {1}, epsilon_common_ratio = {2}", + bidders.size(), + gamma_threshold, + epsilon_common_ratio); + reverse_plot_logger->set_pattern("%v"); + + + + forward_price_stat_logger_file_name = log_filename_prefix + "_forward_price_change_stat"; + forward_price_stat_logger = spdlog::get(forward_price_stat_logger_name); + if (not forward_price_stat_logger) { + forward_price_stat_logger = spdlog::basic_logger_st(forward_price_stat_logger_name, + forward_price_stat_logger_file_name); + } + forward_price_stat_logger->info("New forward price statistics starts here, diagram size = {0}, gamma_threshold = {1}, epsilon_common_ratio = {2}", + bidders.size(), + gamma_threshold, + epsilon_common_ratio); + forward_price_stat_logger->set_pattern("%v"); + + reverse_price_stat_logger_file_name = log_filename_prefix + "_reverse_price_change_stat"; + reverse_price_stat_logger = spdlog::get(reverse_price_stat_logger_name); + if (not reverse_price_stat_logger) { + reverse_price_stat_logger = spdlog::basic_logger_st(reverse_price_stat_logger_name, + reverse_price_stat_logger_file_name); + } + reverse_price_stat_logger->info("New reverse price statistics starts here, diagram size = {0}, gamma_threshold = {1}, epsilon_common_ratio = {2}", + bidders.size(), + gamma_threshold, + epsilon_common_ratio); + reverse_price_stat_logger->set_pattern("%v"); +#endif +} + +template +typename AuctionRunnerFR::Real +AuctionRunnerFR::get_cost_to_diagonal(const DgmPoint& pt) const +{ + if (1.0 == wasserstein_power) { + return pt.persistence_lp(internal_p); + } else { + return std::pow(pt.persistence_lp(internal_p), wasserstein_power); + } +} + + +template +typename AuctionRunnerFR::Real +AuctionRunnerFR::get_gamma() const +{ + if (1.0 == wasserstein_power) { + return unassigned_items_persistence + unassigned_bidders_persistence; + } else { + return std::pow(unassigned_items_persistence + unassigned_bidders_persistence, + 1.0 / wasserstein_power); + } +} + +template +void AuctionRunnerFR::reset_phase_stat() +{ + num_rounds_non_cumulative = 0; +#ifdef LOG_AUCTION + num_parallelizable_rounds = 0; + num_parallelizable_forward_rounds = 0; + num_parallelizable_reverse_rounds = 0; + num_forward_rounds_non_cumulative = 0; + num_reverse_rounds_non_cumulative = 0; +#endif +} + + +template +void AuctionRunnerFR::reset_round_stat() +{ + num_forward_bids_submitted = 0; + num_reverse_bids_submitted = 0; +#ifdef LOG_AUCTION + num_normal_forward_bids_submitted = 0; + num_diag_forward_bids_submitted = 0; + + num_forward_diag_to_diag_assignments = 0; + num_forward_diag_to_normal_assignments = 0; + num_forward_normal_to_diag_assignments = 0; + num_forward_normal_to_normal_assignments = 0; + + num_forward_diag_from_diag_thefts = 0; + num_forward_diag_from_normal_thefts = 0; + num_forward_normal_from_diag_thefts = 0; + num_forward_normal_from_normal_thefts = 0; + + // reverse rounds + num_normal_reverse_bids_submitted = 0; + num_diag_reverse_bids_submitted = 0; + + num_reverse_diag_to_diag_assignments = 0; + num_reverse_diag_to_normal_assignments = 0; + num_reverse_normal_to_diag_assignments = 0; + num_reverse_normal_to_normal_assignments = 0; + + num_reverse_diag_from_diag_thefts = 0; + num_reverse_diag_from_normal_thefts = 0; + num_reverse_normal_from_diag_thefts = 0; + num_reverse_normal_from_normal_thefts = 0; +#endif +} + + +template +void AuctionRunnerFR::assign_forward(IdxType item_idx, IdxType bidder_idx) +{ + console_logger->debug("Enter assign_forward, item_idx = {0}, bidder_idx = {1}", item_idx, bidder_idx); + sanity_check(); + // only unassigned bidders submit bids + assert(bidders_to_items[bidder_idx] == k_invalid_index); + + IdxType old_item_owner = items_to_bidders[item_idx]; + + // set new owner + bidders_to_items[bidder_idx] = item_idx; + items_to_bidders[item_idx] = bidder_idx; + + // remove bidder and item from the sets of unassigned bidders/items + remove_unassigned_bidder(bidder_idx); + + if (k_invalid_index != old_item_owner) { + // old owner of item becomes unassigned + bidders_to_items[old_item_owner] = k_invalid_index; + add_unassigned_bidder(old_item_owner); + // existing edge was removed, decrease partial_cost + partial_cost -= get_item_bidder_cost(item_idx, old_item_owner); + } else { + // item was unassigned before + remove_unassigned_item(item_idx); + } + + // new edge was added to matching, increase partial cost + partial_cost += get_item_bidder_cost(item_idx, bidder_idx); + +#ifdef LOG_AUCTION + + if (unassigned_bidders.size() > parallel_threshold) { + num_parallel_assignments++; + } + num_total_assignments++; + + + int it_d = is_item_diagonal(item_idx); + int b_d = is_bidder_diagonal(bidder_idx); + // 2 - None + int old_d = ( k_invalid_index == old_item_owner ) ? 2 : is_bidder_diagonal(old_item_owner); + int key = 100 * old_d + 10 * b_d + it_d; + switch(key) { + case 211 : num_forward_diag_to_diag_assignments++; + break; + case 210 : num_forward_diag_to_normal_assignments++; + break; + case 201 : num_forward_normal_to_diag_assignments++; + break; + case 200 : num_forward_normal_to_normal_assignments++; + break; + + case 111 : num_forward_diag_to_diag_assignments++; + num_forward_diag_from_diag_thefts++; + break; + case 110 : num_forward_diag_to_normal_assignments++; + num_forward_diag_from_diag_thefts++; + break; + break; + case 101 : num_forward_normal_to_diag_assignments++; + num_forward_normal_from_diag_thefts++; + break; + break; + case 100 : num_forward_normal_to_normal_assignments++; + num_forward_normal_from_diag_thefts++; + break; + + case 11 : num_forward_diag_to_diag_assignments++; + num_forward_diag_from_normal_thefts++; + break; + case 10 : num_forward_diag_to_normal_assignments++; + num_forward_diag_from_normal_thefts++; + break; + break; + case 1 : num_forward_normal_to_diag_assignments++; + num_forward_normal_from_normal_thefts++; + break; + break; + case 0 : num_forward_normal_to_normal_assignments++; + num_forward_normal_from_normal_thefts++; + break; + default : std::cerr << "key = " << key << std::endl; + throw std::runtime_error("Bug in logging, wrong key"); + break; + } +#endif + + sanity_check(); + console_logger->debug("Exit assign_forward, item_idx = {0}, bidder_idx = {1}", item_idx, bidder_idx); +} + + +template +void AuctionRunnerFR::assign_reverse(IdxType item_idx, IdxType bidder_idx) +{ + console_logger->debug("Enter assign_reverse, item_idx = {0}, bidder_idx = {1}", item_idx, bidder_idx); + // only unassigned items submit bids in reverse phase + assert(items_to_bidders[item_idx] == k_invalid_index); + + IdxType old_bidder_owner = bidders_to_items[bidder_idx]; + + // set new owner + bidders_to_items[bidder_idx] = item_idx; + items_to_bidders[item_idx] = bidder_idx; + + // remove bidder and item from the sets of unassigned bidders/items + remove_unassigned_item(item_idx); + + if (k_invalid_index != old_bidder_owner) { + // old owner of item becomes unassigned + items_to_bidders[old_bidder_owner] = k_invalid_index; + add_unassigned_item(old_bidder_owner); + // existing edge was removed, decrease partial_cost + partial_cost -= get_item_bidder_cost(old_bidder_owner, bidder_idx); + } else { + // item was unassigned before + remove_unassigned_bidder(bidder_idx); + } + + // new edge was added to matching, increase partial cost + partial_cost += get_item_bidder_cost(item_idx, bidder_idx); + +#ifdef LOG_AUCTION + if (unassigned_items.size() > parallel_threshold) { + num_parallel_assignments++; + } + num_total_assignments++; + + int it_d = is_item_diagonal(item_idx); + int b_d = is_bidder_diagonal(bidder_idx); + // 2 - None + int old_d = (k_invalid_index == old_bidder_owner) ? 2 : is_item_diagonal(old_bidder_owner); + int key = 100 * old_d + 10 * it_d + b_d; + switch(key) { + case 211 : num_reverse_diag_to_diag_assignments++; + break; + case 210 : num_reverse_diag_to_normal_assignments++; + break; + case 201 : num_reverse_normal_to_diag_assignments++; + break; + case 200 : num_reverse_normal_to_normal_assignments++; + break; + + case 111 : num_reverse_diag_to_diag_assignments++; + num_reverse_diag_from_diag_thefts++; + break; + case 110 : num_reverse_diag_to_normal_assignments++; + num_reverse_diag_from_diag_thefts++; + break; + break; + case 101 : num_reverse_normal_to_diag_assignments++; + num_reverse_normal_from_diag_thefts++; + break; + break; + case 100 : num_reverse_normal_to_normal_assignments++; + num_reverse_normal_from_diag_thefts++; + break; + + case 11 : num_reverse_diag_to_diag_assignments++; + num_reverse_diag_from_normal_thefts++; + break; + case 10 : num_reverse_diag_to_normal_assignments++; + num_reverse_diag_from_normal_thefts++; + break; + break; + case 1 : num_reverse_normal_to_diag_assignments++; + num_reverse_normal_from_normal_thefts++; + break; + break; + case 0 : num_reverse_normal_to_normal_assignments++; + num_reverse_normal_from_normal_thefts++; + break; + default : std::cerr << "key = " << key << std::endl; + throw std::runtime_error("Bug in logging, wrong key"); + break; + } + +#endif + console_logger->debug("Exit assign_reverse, item_idx = {0}, bidder_idx = {1}", item_idx, bidder_idx); +} + +template +typename AuctionRunnerFR::Real +AuctionRunnerFR::get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx) const +{ + if (wasserstein_power == 1.0) { + return dist_lp(bidders[bidder_idx], items[item_idx], internal_p); + } else { + return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p), + wasserstein_power); + } +} + +template +void AuctionRunnerFR::assign_to_best_bidder(IdxType item_idx) +{ + console_logger->debug("Enter assign_to_best_bidder, item_idx = {0}", item_idx); + assert( item_idx >= 0 and item_idx < static_cast(num_items) ); + assert( forward_bid_table[item_idx].first != k_invalid_index); + + auto best_bidder_idx = forward_bid_table[item_idx].first; + auto best_bid_value = forward_bid_table[item_idx].second; + assign_forward(item_idx, best_bidder_idx); + forward_oracle.sanity_check(); + forward_oracle.set_price(item_idx, best_bid_value, true); + forward_oracle.sanity_check(); + auto new_bidder_price = -get_item_bidder_cost(item_idx, best_bidder_idx) - best_bid_value; + reverse_oracle.set_price(best_bidder_idx, new_bidder_price, false); + check_epsilon_css(); +#ifdef LOG_AUCTION + forward_price_change_cnt_vec.back()[item_idx]++; + reverse_price_change_cnt_vec.back()[best_bidder_idx]++; +#endif + console_logger->debug("Exit assign_to_best_bidder, item_idx = {0}", item_idx); +} + +template +void AuctionRunnerFR::assign_to_best_item(IdxType bidder_idx) +{ + console_logger->debug("Enter assign_to_best_item, bidder_idx = {0}", bidder_idx); + check_epsilon_css(); + assert( bidder_idx >= 0 and bidder_idx < static_cast(num_bidders) ); + assert( reverse_bid_table[bidder_idx].first != k_invalid_index); + auto best_item_idx = reverse_bid_table[bidder_idx].first; + auto best_bid_value = reverse_bid_table[bidder_idx].second; + // both assign_forward and assign_reverse take item index first, bidder index second! + assign_reverse(best_item_idx, bidder_idx); + reverse_oracle.sanity_check(); + reverse_oracle.set_price(bidder_idx, best_bid_value, true); + reverse_oracle.sanity_check(); + auto new_item_price = -get_item_bidder_cost(best_item_idx, bidder_idx) - best_bid_value; + forward_oracle.set_price(best_item_idx, new_item_price, false); +#ifdef LOG_AUCTION + forward_price_change_cnt_vec.back()[best_item_idx]++; + reverse_price_change_cnt_vec.back()[bidder_idx]++; +#endif + check_epsilon_css(); + console_logger->debug("Exit assign_to_best_item, bidder_idx = {0}", bidder_idx); +} + +template +void AuctionRunnerFR::clear_forward_bid_table() +{ + auto item_iter = items_with_bids.begin(); + while(item_iter != items_with_bids.end()) { + auto item_with_bid_idx = *item_iter; + forward_bid_table[item_with_bid_idx].first = k_invalid_index; + forward_bid_table[item_with_bid_idx].second = k_lowest_bid_value; + item_iter = items_with_bids.erase(item_iter); + } +} + +template +void AuctionRunnerFR::clear_reverse_bid_table() +{ + auto bidder_iter = bidders_with_bids.begin(); + while(bidder_iter != bidders_with_bids.end()) { + auto bidder_with_bid_idx = *bidder_iter; + reverse_bid_table[bidder_with_bid_idx].first = k_invalid_index; + reverse_bid_table[bidder_with_bid_idx].second = k_lowest_bid_value; + bidder_iter = bidders_with_bids.erase(bidder_iter); + } +} + +template +void AuctionRunnerFR::submit_forward_bid(IdxType bidder_idx, const IdxValPairR& bid) +{ + IdxType best_item_idx = bid.first; + Real bid_value = bid.second; + assert( best_item_idx >= 0 ); + + auto value_in_bid_table = forward_bid_table[best_item_idx].second; + bool new_bid_wins = (value_in_bid_table < bid_value); + // if we have tie, lower persistence wins +// if (value_in_bid_table == bid_value) { +// +// assert(forward_bid_table.at(best_item_idx).first != k_invalid_index); +// assert(&bidders.at(forward_bid_table.at(best_item_idx).first)); +// +// auto bidder_in_bid_table = bidders[forward_bid_table[best_item_idx].first]; +// new_bid_wins = bidders[best_item_idx].persistence_lp(internal_p) < bidder_in_bid_table.persistence_lp(internal_p); +// } + + if (new_bid_wins) { + forward_bid_table[best_item_idx].first = bidder_idx; + forward_bid_table[best_item_idx].second = bid_value; + } + + items_with_bids.insert(best_item_idx); + +#ifdef LOG_AUCTION + + if (unassigned_bidders.size() > parallel_threshold) { + num_parallel_bids++; + } + num_total_bids++; + + + if (is_bidder_diagonal(bidder_idx)) { + num_diag_forward_bids_submitted++; + } else { + num_normal_forward_bids_submitted++; + } +#endif +} + +template +void AuctionRunnerFR::submit_reverse_bid(IdxType item_idx, const IdxValPairR& bid) +{ + assert( items.at(item_idx).is_diagonal() or items.at(item_idx).is_normal() ); + IdxType best_bidder_idx = bid.first; + assert( bidders.at(best_bidder_idx).is_diagonal() or bidders.at(best_bidder_idx).is_normal() ); + Real bid_value = bid.second; + assert(bid_value > k_lowest_bid_value); + auto value_in_bid_table = reverse_bid_table[best_bidder_idx].second; + bool new_bid_wins = (value_in_bid_table < bid_value); + // if we have tie, lower persistence wins +// if (value_in_bid_table == bid_value) { +// assert(reverse_bid_table[best_bidder_idx].first != k_invalid_index); +// auto bidder_in_bid_table = bidders[reverse_bid_table[best_bidder_idx].first]; +// new_bid_wins = bidders[best_bidder_idx].persistence_lp(internal_p) < bidder_in_bid_table.persistence_lp(internal_p); +// } + if (new_bid_wins) { + reverse_bid_table[best_bidder_idx].first = item_idx; + reverse_bid_table[best_bidder_idx].second = bid_value; + } + bidders_with_bids.insert(best_bidder_idx); + +#ifdef LOG_AUCTION + + if (unassigned_items.size() > parallel_threshold) { + num_parallel_bids++; + } + num_total_bids++; + + if (is_item_diagonal(item_idx)) { + num_diag_reverse_bids_submitted++; + } else { + num_normal_reverse_bids_submitted++; + } +#endif +} + + +template +void AuctionRunnerFR::print_debug() +{ +#ifdef DEBUG_FR_AUCTION + std::cout << "**********************" << std::endl; + std::cout << "Current assignment:" << std::endl; + for(size_t idx = 0; idx < bidders_to_items.size(); ++idx) { + std::cout << idx << " <--> " << bidders_to_items[idx] << std::endl; + } + std::cout << "Weights: " << std::endl; + //for(size_t i = 0; i < num_bidders; ++i) { + //for(size_t j = 0; j < num_items; ++j) { + //std::cout << oracle.weight_matrix[i][j] << " "; + //} + //std::cout << std::endl; + //} + std::cout << "Bidder prices: " << std::endl; + for(const auto price : forward_oracle.get_prices()) { + std::cout << price << std::endl; + } + std::cout << "**********************" << std::endl; +#endif +} + + +template +typename AuctionRunnerFR::Real +AuctionRunnerFR::get_relative_error(const bool debug_output) const +{ + Real result; + Real gamma = get_gamma(); + // cost minus n epsilon + Real reduced_cost = partial_cost - num_bidders * get_epsilon(); + if ( reduced_cost < 0) { +#ifdef LOG_AUCTION + if (debug_output) { + console_logger->debug("Epsilon too large, reduced_cost = {0}", reduced_cost); + } +#endif + result = k_max_relative_error; + } else { + Real denominator = std::pow(reduced_cost, 1.0 / wasserstein_power) - gamma; + if (denominator <= 0) { +#ifdef LOG_AUCTION + if (debug_output) { + console_logger->debug("Epsilon too large, reduced_cost = {0}, denominator = {1}, gamma = {2}", reduced_cost, denominator, gamma); + } +#endif + result = k_max_relative_error; + } else { + Real numerator = 2 * gamma + + std::pow(partial_cost, 1.0 / wasserstein_power) - + std::pow(reduced_cost, 1.0 / wasserstein_power); + + result = numerator / denominator; +#ifdef LOG_AUCTION + if (debug_output) { + console_logger->debug("Reduced_cost = {0}, denominator = {1}, numerator {2}, error = {3}, gamma = {4}", + reduced_cost, + denominator, + numerator, + result, + gamma); + } +#endif + } + } + return result; +} + +template +void AuctionRunnerFR::flush_assignment() +{ + console_logger->debug("Enter flush_assignment"); + for(auto& b2i : bidders_to_items) { + b2i = k_invalid_index; + } + for(auto& i2b : items_to_bidders) { + i2b = k_invalid_index; + } + + // all bidders and items become unassigned + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + unassigned_bidders.insert(bidder_idx); + } + + // all items and items become unassigned + for(size_t item_idx = 0; item_idx < num_items; ++item_idx) { + unassigned_items.insert(item_idx); + } + + + //forward_oracle.adjust_prices(); + //reverse_oracle.adjust_prices(); + + partial_cost = 0.0; + unassigned_bidders_persistence = total_bidders_persistence; + unassigned_items_persistence = total_items_persistence; + +#ifdef ORDERED_BY_PERSISTENCE + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + unassigned_bidders_by_persistence.insert(std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)); + } + + for(size_t item_idx = 0; item_idx < num_items; ++item_idx) { + unassigned_items_by_persistence.insert(std::make_pair(items[item_idx].persistence_lp(1.0), item_idx)); + } +#endif + +#ifdef LOG_AUCTION + + reset_phase_stat(); + + forward_price_change_cnt_vec.push_back(std::vector(num_items, 0)); + reverse_price_change_cnt_vec.push_back(std::vector(num_bidders, 0)); + + // all bidders and items become unassigned + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if (is_bidder_normal(bidder_idx)) { + unassigned_normal_bidders.insert(bidder_idx); + } else { + unassigned_diag_bidders.insert(bidder_idx); + } + } + + never_assigned_bidders = unassigned_bidders; + + for(size_t item_idx = 0; item_idx < items.size(); ++item_idx) { + if (is_item_normal(item_idx)) { + unassigned_normal_items.insert(item_idx); + } else { + unassigned_diag_items.insert(item_idx); + } + } + + never_assigned_items = unassigned_items; +#endif + check_epsilon_css(); + console_logger->debug("Exit flush_assignment"); +} + +template +void AuctionRunnerFR::set_epsilon(Real new_val) +{ + assert(new_val > 0.0); + epsilon = new_val; + forward_oracle.set_epsilon(new_val); + reverse_oracle.set_epsilon(new_val); +} + + +template +bool AuctionRunnerFR::continue_forward(const size_t original_unassigned_bidders, const size_t min_forward_matching_increment) +{ +// if (unassigned_threshold == 0) { +// return not unassigned_bidders.empty() and get_relative_error(false) > delta; +// } + //return unassigned_bidders.size() > unassigned_threshold and + //static_cast(unassigned_bidders.size()) >= static_cast(original_unassigned_bidders) - static_cast(min_forward_matching_increment); + return unassigned_bidders.size() > unassigned_threshold and + static_cast(unassigned_bidders.size()) >= static_cast(original_unassigned_bidders) - static_cast(min_forward_matching_increment) and + get_relative_error() >= delta; +// return not unassigned_bidders.empty() and +// static_cast(unassigned_bidders.size()) >= static_cast(original_unassigned_bidders) - static_cast(min_forward_matching_increment) and +// get_relative_error() >= delta; +} + + +template +bool AuctionRunnerFR::continue_reverse(const size_t original_unassigned_items, const size_t min_reverse_matching_increment) +{ + //return unassigned_items.size() > unassigned_threshold and + //static_cast(unassigned_items.size()) >= static_cast(original_unassigned_items) - static_cast(min_reverse_matching_increment); + return unassigned_items.size() > unassigned_threshold and + static_cast(unassigned_items.size()) >= static_cast(original_unassigned_items) - static_cast(min_reverse_matching_increment) and + get_relative_error() >= delta; +// return not unassigned_items.empty() and +// static_cast(unassigned_items.size()) >= static_cast(original_unassigned_items) - static_cast(min_reverse_matching_increment) and +// get_relative_error() >= delta; +} + + +template +bool AuctionRunnerFR::continue_phase() +{ + //return not unassigned_bidders.empty(); + return unassigned_bidders.size() > unassigned_threshold and get_relative_error() >= delta; +// return not never_assigned_bidders.empty() or +// not never_assigned_items.empty() or +// unassigned_bidders.size() > unassigned_threshold and get_relative_error() >= delta; +} + + + +template +void AuctionRunnerFR::run_auction_phase() +{ + num_phase++; + while(continue_phase()) { + forward_oracle.recompute_top_diag_items(true); + forward_oracle.sanity_check(); + console_logger->debug("forward_oracle recompute_top_diag_items done"); + run_forward_auction_phase(); + reverse_oracle.recompute_top_diag_items(true); + console_logger->debug("reverse_oracle recompute_top_diag_items done"); + reverse_oracle.sanity_check(); + run_reverse_auction_phase(); + } +} + +template +void AuctionRunnerFR::run_auction_phases(const int max_num_phases, const Real _initial_epsilon) +{ + set_epsilon(_initial_epsilon); + assert( forward_oracle.get_epsilon() > 0 ); + assert( reverse_oracle.get_epsilon() > 0 ); + for(int phase_num = 0; phase_num < max_num_phases; ++phase_num) { + flush_assignment(); + console_logger->info("Phase {0} started: eps = {1}", + num_phase, + get_epsilon()); + + run_auction_phase(); + Real current_result = partial_cost; +#ifdef LOG_AUCTION + console_logger->info("Phase {0} done: current_result = {1}, eps = {2}, unassigned_threshold = {3}, unassigned = {4}, error = {5}, gamma = {6}", + num_phase, + partial_cost, + get_epsilon(), + format_int<>(unassigned_threshold), + unassigned_bidders.size(), + get_relative_error(false), + get_gamma()); + + console_logger->info("Phase {0} done: num_rounds / num_parallelizable_rounds = {1} / {2} = {3}, cumulative rounds = {4}", + num_phase, + format_int(num_rounds_non_cumulative), + format_int(num_parallelizable_rounds), + static_cast(num_parallelizable_rounds) / static_cast(num_rounds_non_cumulative), + format_int(num_rounds) + ); + + console_logger->info("parallelizable_forward_rounds / num_forward_rounds = {0} / {1} = {2}", + format_int<>(num_parallelizable_forward_rounds), + format_int<>(num_forward_rounds_non_cumulative), + static_cast(num_parallelizable_forward_rounds) / static_cast(num_forward_rounds_non_cumulative) + ); + + num_parallelizable_forward_rounds = 0; + num_forward_rounds_non_cumulative = 0; + + console_logger->info("parallelizable_reverse_rounds / num_reverse_rounds = {0} / {1} = {2}", + format_int<>(num_parallelizable_reverse_rounds), + format_int<>(num_reverse_rounds_non_cumulative), + static_cast(num_parallelizable_reverse_rounds) / static_cast(num_reverse_rounds_non_cumulative) + ); + + num_parallelizable_reverse_rounds = 0; + num_reverse_rounds_non_cumulative = 0; + + console_logger->info("num_parallel_bids / num_total_bids = {0} / {1} = {2}, num_parallel_assignments / num_total_assignments = {3} / {4} = {5}", + format_int<>(num_parallel_bids), + format_int<>(num_total_bids), + static_cast(num_parallel_bids) / static_cast(num_total_bids), + format_int<>(num_parallel_assignments), + format_int<>(num_total_assignments), + static_cast(num_parallel_assignments) / static_cast(num_total_assignments) + ); + + auto forward_min_max_price = forward_oracle.get_minmax_price(); + auto reverse_min_max_price = reverse_oracle.get_minmax_price(); + + console_logger->info("forward min price = {0}, max price = {1}; reverse min price = {2}, reverse max price = {3}", + forward_min_max_price.first, + forward_min_max_price.second, + reverse_min_max_price.first, + reverse_min_max_price.second + ); + + for(size_t item_idx = 0; item_idx < num_items; ++item_idx) { + forward_price_stat_logger->info("{0} {1} {2} {3} {4}", + phase_num, + item_idx, + items[item_idx].getRealX(), + items[item_idx].getRealY(), + forward_price_change_cnt_vec.back()[item_idx] + ); + } + + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + reverse_price_stat_logger->info("{0} {1} {2} {3} {4}", + phase_num, + bidder_idx, + bidders[bidder_idx].getRealX(), + bidders[bidder_idx].getRealY(), + reverse_price_change_cnt_vec.back()[bidder_idx] + ); + } +#endif + + if (get_relative_error(true) <= delta) { + break; + } + // decrease epsilon for the next iteration + decrease_epsilon(); + + unassigned_threshold = std::floor( static_cast(unassigned_threshold) / 1.1 ); + + if (phase_can_be_final()) { + unassigned_threshold = 0; +#ifdef LOG_AUCTION + console_logger->info("Unassigned threshold set to zero!"); +#endif + } + } +} + +template +bool AuctionRunnerFR::phase_can_be_final() const +{ + Real estimated_error; + // cost minus n epsilon + Real reduced_cost = partial_cost - num_bidders * get_epsilon(); + if (reduced_cost <= 0.0) { + return false; + } else { + Real denominator = std::pow(reduced_cost, 1.0 / wasserstein_power); + if (denominator <= 0) { + return false; + } else { + Real numerator = std::pow(partial_cost, 1.0 / wasserstein_power) - + std::pow(reduced_cost, 1.0 / wasserstein_power); + + estimated_error = numerator / denominator; + return estimated_error <= delta; + } + } +} + +template +void AuctionRunnerFR::run_auction() +{ + double init_eps = ( initial_epsilon > 0.0 ) ? initial_epsilon : std::min(forward_oracle.max_val_, reverse_oracle.max_val_) / 4.0 ; + assert(init_eps > 0.0); + run_auction_phases(max_num_phases, init_eps); + is_distance_computed = true; + wasserstein_cost = partial_cost; + if (get_relative_error() > delta) { +#ifndef FOR_R_TDA + std::cerr << "Maximum iteration number exceeded, exiting. Current result is: "; + std::cerr << get_wasserstein_distance() << std::endl; +#endif + throw std::runtime_error("Maximum iteration number exceeded"); + } +} + +template +void AuctionRunnerFR::add_unassigned_bidder(const size_t bidder_idx) +{ + const DgmPoint& bidder = bidders[bidder_idx]; + unassigned_bidders.insert(bidder_idx); + unassigned_bidders_persistence += get_cost_to_diagonal(bidder); + +#ifdef ORDERED_BY_PERSISTENCE + unassigned_bidders_by_persistence.insert(std::make_pair(bidder.persistence_lp(1.0), bidder_idx)); +#endif + +#ifdef LOG_AUCTION + if (is_bidder_diagonal(bidder_idx)) { + unassigned_diag_bidders.insert(bidder_idx); + } else { + unassigned_normal_bidders.insert(bidder_idx); + } +#endif +} + +template +void AuctionRunnerFR::add_unassigned_item(const size_t item_idx) +{ + const DgmPoint& item = items[item_idx]; + unassigned_items.insert(item_idx); + unassigned_items_persistence += get_cost_to_diagonal(item); + +#ifdef ORDERED_BY_PERSISTENCE + unassigned_items_by_persistence.insert(std::make_pair(item.persistence_lp(1.0), item_idx)); +#endif + +#ifdef LOG_AUCTION + if (is_item_diagonal(item_idx)) { + unassigned_diag_items.insert(item_idx); + } else { + unassigned_normal_items.insert(item_idx); + } +#endif +} + + +template +void AuctionRunnerFR::remove_unassigned_bidder(const size_t bidder_idx) +{ + unassigned_bidders_persistence -= get_cost_to_diagonal(bidders[bidder_idx]); + + unassigned_bidders.erase(bidder_idx); + never_assigned_bidders.erase(bidder_idx); + +#ifdef ORDERED_BY_PERSISTENCE + unassigned_bidders_by_persistence.erase(std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)); +#endif + +#ifdef LOG_AUCTION + if (is_bidder_diagonal(bidder_idx)) { + unassigned_diag_bidders.erase(bidder_idx); + } else { + unassigned_normal_bidders.erase(bidder_idx); + } + if (never_assigned_bidders.empty() and not all_assigned_round_found) { + all_assigned_round = num_rounds_non_cumulative; + all_assigned_round_found = true; + } +#endif +} + +template +void AuctionRunnerFR::remove_unassigned_item(const size_t item_idx) +{ + console_logger->debug("Enter remove_unassigned_item, unassigned_items.size = {0}", unassigned_items.size()); + unassigned_items_persistence -= get_cost_to_diagonal(items[item_idx]); + + never_assigned_items.erase(item_idx); + unassigned_items.erase(item_idx); + +#ifdef ORDERED_BY_PERSISTENCE + unassigned_items_by_persistence.erase(std::make_pair(items[item_idx].persistence_lp(1.0), item_idx)); +#endif + +#ifdef LOG_AUCTION + if (is_item_normal(item_idx)) { + unassigned_normal_items.erase(item_idx); + } else { + unassigned_diag_items.erase(item_idx); + } + if (never_assigned_items.empty() and not all_assigned_round_found) { + all_assigned_round = num_rounds_non_cumulative; + all_assigned_round_found = true; + } +#endif + console_logger->debug("Exit remove_unassigned_item, unassigned_items.size = {0}", unassigned_items.size()); +} + +template +void AuctionRunnerFR::decrease_epsilon() +{ + auto eps_diff = 1.01 * get_epsilon() * (epsilon_common_ratio - 1.0 ) / epsilon_common_ratio; + reverse_oracle.adjust_prices( -eps_diff ); + set_epsilon( get_epsilon() / epsilon_common_ratio ); + cumulative_epsilon_factor *= epsilon_common_ratio; +} + + + +template +void AuctionRunnerFR::run_reverse_auction_phase() +{ + console_logger->debug("Enter run_reverse_auction_phase"); + size_t original_unassigned_items = unassigned_items.size(); +// const size_t min_reverse_matching_increment = std::max( static_cast(1), static_cast(original_unassigned_items / 10)); + size_t min_reverse_matching_increment = 1; + + while (continue_reverse(original_unassigned_items, min_reverse_matching_increment)) { + num_rounds++; + num_rounds_non_cumulative++; + console_logger->debug("started round = {0}, reverse, unassigned = {1}", num_rounds, unassigned_items.size()); + + check_epsilon_css(); +#ifdef LOG_AUCTION + if (unassigned_items.size() >= parallel_threshold) { + ++num_parallelizable_reverse_rounds; + ++num_parallelizable_rounds; + } + num_reverse_rounds++; + num_reverse_rounds_non_cumulative++; +#endif + + reset_round_stat(); + // bidding +#ifdef ORDERED_BY_PERSISTENCE + std::vector active_items; + active_items.reserve(batch_size); + for(auto iter = unassigned_items_by_persistence.begin(); + iter != unassigned_items_by_persistence.end(); ++iter) { + active_items.push_back(iter->second); + if (active_items.size() >= batch_size) { + break; + } + } + run_reverse_bidding_step(active_items); +#else + //if (not never_assigned_items.empty()) + //run_reverse_bidding_step(never_assigned_items); + //else + //run_reverse_bidding_step(unassigned_items); + run_reverse_bidding_step(unassigned_items); +#endif + + // assignment phase + for(auto bidder_idx : bidders_with_bids ) { + assign_to_best_item(bidder_idx); + } + + check_epsilon_css(); + + console_logger->debug("ended round = {0}, reverse, unassigned = {1}", num_rounds, unassigned_items.size()); + +#ifdef LOG_AUCTION + + reverse_plot_logger->info("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9} {10} {11} {12} {13} {14} {15} {16} {17} {18} {19} {20} {21} {22}", + num_phase, + num_rounds, + num_reverse_rounds, + unassigned_bidders.size(), + get_gamma(), + partial_cost, + reverse_oracle.get_epsilon(), + num_normal_reverse_bids_submitted, + num_diag_reverse_bids_submitted, + num_reverse_diag_to_diag_assignments, + num_reverse_diag_to_normal_assignments, + num_reverse_normal_to_diag_assignments, + num_reverse_normal_to_normal_assignments, + num_reverse_diag_from_diag_thefts, + num_reverse_diag_from_normal_thefts, + num_reverse_normal_from_diag_thefts, + num_reverse_normal_from_normal_thefts, + unassigned_normal_bidders.size(), + unassigned_diag_bidders.size(), + unassigned_normal_items.size(), + unassigned_diag_items.size(), + reverse_oracle.get_heap_top_size(), + get_relative_error(false) + ); + sanity_check(); +#endif + } +} + +template +template +void AuctionRunnerFR::run_forward_bidding_step(const Range& active_bidders) +{ + clear_forward_bid_table(); + for(const auto bidder_idx : active_bidders) { + console_logger->debug("current bidder (forward): {0}, persistence = {1}", bidders[bidder_idx], bidders[bidder_idx].persistence_lp(1.0)); + submit_forward_bid(bidder_idx, forward_oracle.get_optimal_bid(bidder_idx)); + if (++num_forward_bids_submitted >= max_bids_per_round) { + break; + } + } +} + +template +template +void AuctionRunnerFR::run_reverse_bidding_step(const Range& active_items) +{ + clear_reverse_bid_table(); + + assert(bidders_with_bids.empty()); + assert(std::all_of(reverse_bid_table.begin(), reverse_bid_table.end(), + [ki = k_invalid_index, kl = k_lowest_bid_value](const IdxValPairR& b) { return static_cast(b.first) == ki and b.second == kl; })); + + for(const auto item_idx : active_items) { + console_logger->debug("current bidder (reverse): {0}, persistence = {1}", items[item_idx], items[item_idx].persistence_lp(1.0)); + submit_reverse_bid(item_idx, reverse_oracle.get_optimal_bid(item_idx)); + if (++num_reverse_bids_submitted >= max_bids_per_round) { + break; + } + } +} + + +template +void AuctionRunnerFR::run_forward_auction_phase() +{ + const size_t original_unassigned_bidders = unassigned_bidders.size(); +// const size_t min_forward_matching_increment = std::max( static_cast(1), static_cast(original_unassigned_bidders / 10)); + const size_t min_forward_matching_increment = 1; + while (continue_forward(original_unassigned_bidders, min_forward_matching_increment)) { + console_logger->debug("started round = {0}, forward, unassigned = {1}", num_rounds, unassigned_bidders.size()); + check_epsilon_css(); + num_rounds++; +#ifdef LOG_AUCTION + if (unassigned_bidders.size() >= parallel_threshold) { + ++num_parallelizable_forward_rounds; + ++num_parallelizable_rounds; + } + num_forward_rounds++; + num_forward_rounds_non_cumulative++; +#endif + + reset_round_stat(); + // bidding step +#ifdef ORDERED_BY_PERSISTENCE + std::vector active_bidders; + active_bidders.reserve(batch_size); + for(auto iter = unassigned_bidders_by_persistence.begin(); + iter != unassigned_bidders_by_persistence.end(); ++iter) { + active_bidders.push_back(iter->second); + if (active_bidders.size() >= batch_size) { + break; + } + } + run_forward_bidding_step(active_bidders); +#else + + //if (not never_assigned_bidders.empty()) + //run_forward_bidding_step(never_assigned_bidders); + //else + //run_forward_bidding_step(unassigned_bidders); + run_forward_bidding_step(unassigned_bidders); +#endif + + // assignment step + for(auto item_idx : items_with_bids ) { + assign_to_best_bidder(item_idx); + } + + console_logger->debug("ended round = {0}, forward, unassigned = {1}", num_rounds, unassigned_bidders.size()); + check_epsilon_css(); + +#ifdef LOG_AUCTION + forward_plot_logger->info("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9} {10} {11} {12} {13} {14} {15} {16} {17} {18} {19} {20} {21} {22}", + num_phase, + num_rounds, + num_forward_rounds, + unassigned_bidders.size(), + get_gamma(), + partial_cost, + forward_oracle.get_epsilon(), + num_normal_forward_bids_submitted, + num_diag_forward_bids_submitted, + num_forward_diag_to_diag_assignments, + num_forward_diag_to_normal_assignments, + num_forward_normal_to_diag_assignments, + num_forward_normal_to_normal_assignments, + num_forward_diag_from_diag_thefts, + num_forward_diag_from_normal_thefts, + num_forward_normal_from_diag_thefts, + num_forward_normal_from_normal_thefts, + unassigned_normal_bidders.size(), + unassigned_diag_bidders.size(), + unassigned_normal_items.size(), + unassigned_diag_items.size(), + forward_oracle.get_heap_top_size(), + get_relative_error(false) + ); +#endif + } ; + +} +template +void AuctionRunnerFR::assign_diag_to_diag() +{ + size_t n_diag_to_diag = std::min(num_diag_bidders, num_diag_items); + if (n_diag_to_diag < 2) + return; + for(size_t i = 0; i < n_diag_to_diag; ++i) { + } +} + +template +typename AuctionRunnerFR::Real +AuctionRunnerFR::get_wasserstein_distance() +{ + assert(is_distance_computed); + return std::pow(wasserstein_cost, 1.0 / wasserstein_power); +} + +template +typename AuctionRunnerFR::Real +AuctionRunnerFR::get_wasserstein_cost() +{ + assert(is_distance_computed); + return wasserstein_cost; +} + + + +template +void AuctionRunnerFR::sanity_check() +{ +#ifdef DEBUG_FR_AUCTION + assert(partial_cost >= 0); + + + assert(num_diag_items == num_normal_bidders); + assert(num_diag_bidders == num_normal_items); + assert(num_diag_bidders + num_normal_bidders == num_bidders); + assert(num_diag_items + num_normal_items == num_items); + assert(num_items == num_bidders); + + + for(size_t b = 0; b < num_bidders; ++b) { + assert( is_bidder_diagonal(b) == bidders.at(b).is_diagonal() ); + assert( is_bidder_normal(b) == bidders.at(b).is_normal() ); + } + + for(size_t i = 0; i < num_items; ++i) { + assert( is_item_diagonal(i) == items.at(i).is_diagonal() ); + assert( is_item_normal(i) == items.at(i).is_normal() ); + } + + // check matching consistency + assert(bidders_to_items.size() == num_bidders); + assert(items_to_bidders.size() == num_bidders); + + assert(std::count(bidders_to_items.begin(), bidders_to_items.end(), k_invalid_index) == std::count(items_to_bidders.begin(), items_to_bidders.end(), k_invalid_index)); + + Real true_partial_cost = 0.0; + + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if (bidders_to_items[bidder_idx] != k_invalid_index) { + assert(items_to_bidders.at(bidders_to_items[bidder_idx]) == static_cast(bidder_idx)); + true_partial_cost += get_item_bidder_cost(bidders_to_items[bidder_idx], bidder_idx); + } + } + + assert(fabs(partial_cost - true_partial_cost) < 0.00001); + + for(size_t item_idx = 0; item_idx < num_items; ++item_idx) { + if (items_to_bidders[item_idx] != k_invalid_index) { + assert(bidders_to_items.at(items_to_bidders[item_idx]) == static_cast(item_idx)); + } + } + +#ifdef ORDERED_BY_PERSISTENCE + assert(unassigned_bidders.size() == unassigned_bidders_by_persistence.size()); + if (unassigned_items.size() != unassigned_items_by_persistence.size()) { + console_logger->error("unassigned_items.size() = {0}, unassigned_items_by_persistence.size() = {1}", unassigned_items.size(),unassigned_items_by_persistence.size()); + console_logger->error("unassigned_items = {0}, unassigned_items_by_persistence = {1}", format_container_to_log(unassigned_items),format_pair_container_to_log(unassigned_items_by_persistence)); + } + assert(unassigned_items.size() == unassigned_items_by_persistence.size()); + + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if (bidders_to_items[bidder_idx] == k_invalid_index) { + assert(unassigned_bidders.count(bidder_idx) == 1); + assert(unassigned_bidders_by_persistence.count(std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)) == 1); + } else { + assert(unassigned_bidders.count(bidder_idx) == 0); + assert(unassigned_bidders_by_persistence.count(std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)) == 0); + } + } + + for(size_t item_idx = 0; item_idx < num_items; ++item_idx) { + if (items_to_bidders[item_idx] == k_invalid_index) { + assert(unassigned_items.count(item_idx) == 1); + assert(unassigned_items_by_persistence.count(std::make_pair(items[item_idx].persistence_lp(1.0), item_idx)) == 1); + } else { + assert(unassigned_items.count(item_idx) == 0); + assert(unassigned_items_by_persistence.count(std::make_pair(items[item_idx].persistence_lp(1.0), item_idx)) == 0); + } + } +#endif + + +#endif +} + +template +void AuctionRunnerFR::check_epsilon_css() +{ +#ifdef DEBUG_FR_AUCTION + sanity_check(); + + std::vector b_prices = reverse_oracle.get_prices(); + std::vector i_prices = forward_oracle.get_prices(); + double eps = forward_oracle.get_epsilon(); + + for(size_t b = 0; b < num_bidders; ++b) { + for(size_t i = 0; i < num_items; ++i) { + if(((is_bidder_normal(b) and is_item_diagonal(i)) or (is_bidder_diagonal(b) and is_item_normal(i))) and b != i) + continue; + if (b_prices[b] + i_prices[i] + eps < -get_item_bidder_cost(i, b) - 0.000001) { + console_logger->debug("b = {0}, i = {1}, eps = {2}, b_price = {3}, i_price[i] = {4}, cost = {5}, b_price + i_price + eps = {6}", + b, + i, + eps, + b_prices[b], + i_prices[i], + get_item_bidder_cost(i, b), + b_prices[b] + i_prices[i] + eps + ); + } + assert(b_prices[b] + i_prices[i] + eps >= -get_item_bidder_cost(i, b) - 0.000001); + } + } + + for(size_t b = 0; b < num_bidders; ++b) { + auto i = bidders_to_items[b]; + if (i != k_invalid_index) { + assert( fabs(b_prices[b] + i_prices[i] + get_item_bidder_cost(i, b)) < 0.000001 ); + } + } +#endif +} + +template +void AuctionRunnerFR::print_matching() +{ +#ifdef DEBUG_FR_AUCTION + sanity_check(); + for(size_t bidder_idx = 0; bidder_idx < bidders_to_items.size(); ++bidder_idx) { + if (bidders_to_items[bidder_idx] >= 0) { + auto pA = bidders[bidder_idx]; + auto pB = items[bidders_to_items[bidder_idx]]; + std::cout << pA << " <-> " << pB << "+" << pow(dist_lp(pA, pB, internal_p), wasserstein_power) << std::endl; + } else { + assert(false); + } + } +#endif +} + +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/auction_runner_gs.h b/src/dionysus/wasserstein/auction_runner_gs.h new file mode 100755 index 0000000..fc76987 --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_gs.h @@ -0,0 +1,122 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_RUNNER_GS_H +#define AUCTION_RUNNER_GS_H + +#include +#include + +#include "spdlog/spdlog.h" +#include "auction_oracle.h" + +namespace hera { +namespace ws { + +template, class PointContainer_ = std::vector> > // alternatively: AuctionOracleLazyHeap --- TODO +class AuctionRunnerGS { +public: + using Real = RealType_; + using AuctionOracle = AuctionOracle_; + using DgmPoint = typename AuctionOracle::DiagramPointR; + using IdxValPairR = IdxValPair; + using PointContainer = PointContainer_; + + + AuctionRunnerGS(const PointContainer& A, + const PointContainer& B, + const AuctionParams& params, + const std::string& _log_filename_prefix = ""); + + void set_epsilon(Real new_val) { assert(epsilon > 0.0); epsilon = new_val; }; + Real get_epsilon() const { return oracle.get_epsilon(); } + Real get_wasserstein_cost(); + Real get_wasserstein_distance(); + Real get_relative_error() const { return relative_error; }; + void enable_logging(const char* log_filename, const size_t _max_unassigned_to_log); +//private: + // private data + PointContainer bidders, items; + const size_t num_bidders; + const size_t num_items; + std::vector items_to_bidders; + std::vector bidders_to_items; + Real wasserstein_power; + Real epsilon; + Real delta; + Real internal_p; + Real initial_epsilon; + Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio + const int max_num_phases; // maximal number of iterations of epsilon-scaling + Real weight_adj_const; + Real wasserstein_cost; + Real relative_error; + int dimension; + // to get the 2 best items + AuctionOracle oracle; + std::unordered_set unassigned_bidders; + // private methods + void assign_item_to_bidder(const IdxType bidder_idx, const IdxType items_idx); + void run_auction(); + void run_auction_phases(const int max_num_phases, const Real _initial_epsilon); + void run_auction_phase(); + void flush_assignment(); + // return 0, if item_idx is invalid + Real get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx, const bool tolerate_invalid_idx = false) const; + + // for debug only + void sanity_check(); + void print_debug(); + int count_unhappy(); + void print_matching(); + Real getDistanceToQthPowerInternal(); + int num_phase { 0 }; + int num_rounds { 0 }; + bool is_distance_computed {false}; +#ifdef LOG_AUCTION + bool log_auction { false }; + std::shared_ptr console_logger; + std::shared_ptr plot_logger; + std::unordered_set unassigned_items; + size_t max_unassigned_to_log { 0 }; + const char* logger_name = "auction_detailed_logger"; // the name in spdlog registry; filename is provided as parameter in enable_logging + const Real total_items_persistence; + const Real total_bidders_persistence; + Real partial_cost; + Real unassigned_bidders_persistence; + Real unassigned_items_persistence; +#endif +}; + +} // ws +} // hera + + +#include "auction_runner_gs.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_runner_gs.hpp b/src/dionysus/wasserstein/auction_runner_gs.hpp new file mode 100755 index 0000000..d9f419d --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_gs.hpp @@ -0,0 +1,486 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + + +#include +#include +#include +#include +#include +#include +#include + +#include "def_debug_ws.h" + +#define PRINT_DETAILED_TIMING + +#ifdef FOR_R_TDA +#include "Rcpp.h" +#undef DEBUG_AUCTION +#endif + + +namespace hera { +namespace ws { + +// ***************************** +// AuctionRunnerGS +// ***************************** + +template +AuctionRunnerGS::AuctionRunnerGS(const PC& A, + const PC& B, + const AuctionParams& params, + const std::string& _log_filename_prefix) : + bidders(A), + items(B), + num_bidders(A.size()), + num_items(B.size()), + items_to_bidders(B.size(), k_invalid_index), + bidders_to_items(A.size(), k_invalid_index), + wasserstein_power(params.wasserstein_power), + delta(params.delta), + internal_p(params.internal_p), + initial_epsilon(params.initial_epsilon), + epsilon_common_ratio(params.epsilon_common_ratio == 0.0 ? 5.0 : params.epsilon_common_ratio), + max_num_phases(params.max_num_phases), + dimension(params.dim), + oracle(bidders, items, params) +#ifdef LOG_AUCTION + , total_items_persistence(std::accumulate(items.begin(), + items.end(), + R(0.0), + [params](const Real& ps, const DgmPoint& item) + { return ps + std::pow(item.persistence_lp(params.internal_p), params.wasserstein_power); } + )) + + , total_bidders_persistence(std::accumulate(bidders.begin(), + bidders.end(), + R(0.0), + [params](const Real& ps, const DgmPoint& bidder) + { return ps + std::pow(bidder.persistence_lp(params.internal_p), params.wasserstein_power); } + )) + , partial_cost(0.0) + , unassigned_bidders_persistence(0.0) + , unassigned_items_persistence(0.0) +#endif + +{ + assert(initial_epsilon >= 0.0 ); + assert(epsilon_common_ratio >= 0.0 ); + assert(A.size() == B.size()); +#ifdef LOG_AUCTION + + unassigned_items_persistence = total_items_persistence; + unassigned_bidders_persistence = total_bidders_persistence; + + console_logger = spdlog::get("console"); + if (not console_logger) { + console_logger = spdlog::stdout_logger_st("console"); + } + console_logger->set_pattern("[%H:%M:%S.%e] %v"); + console_logger->debug("Gauss-Seidel, num_bidders = {0}", num_bidders); + + plot_logger = spdlog::get("plot_logger"); + if (not plot_logger) { + plot_logger = spdlog::basic_logger_st("plot_logger", "plot_logger.txt"); + plot_logger->info("New plot starts here"); + plot_logger->set_pattern("%v"); + } +#endif + +} + +#ifdef LOG_AUCTION +template +void AuctionRunnerGS::enable_logging(const char* log_filename, const size_t _max_unassigned_to_log) +{ + log_auction = true; + max_unassigned_to_log = _max_unassigned_to_log; + + auto log = spdlog::basic_logger_st(logger_name, log_filename); + log->set_pattern("%v"); +} +#endif + +template +void AuctionRunnerGS::assign_item_to_bidder(IdxType item_idx, IdxType bidder_idx) +{ + num_rounds++; + sanity_check(); + // only unassigned bidders should submit bids and get items + assert(bidders_to_items[bidder_idx] == k_invalid_index); + IdxType old_item_owner = items_to_bidders[item_idx]; + + // set new owner + bidders_to_items[bidder_idx] = item_idx; + items_to_bidders[item_idx] = bidder_idx; + // remove bidder from the list of unassigned bidders + unassigned_bidders.erase(bidder_idx); + + // old owner becomes unassigned + if (old_item_owner != k_invalid_index) { + bidders_to_items[old_item_owner] = k_invalid_index; + unassigned_bidders.insert(old_item_owner); + } + + +#ifdef LOG_AUCTION + + partial_cost += get_item_bidder_cost(item_idx, bidder_idx, true); + partial_cost -= get_item_bidder_cost(item_idx, old_item_owner, true); + + unassigned_items.erase(item_idx); + + unassigned_bidders_persistence -= std::pow(bidders[bidder_idx].persistence_lp(internal_p), wasserstein_power); + + if (old_item_owner != k_invalid_index) { + // item has been assigned to some other bidder, + // and he became unassigned + unassigned_bidders_persistence += std::pow(bidders[old_item_owner].persistence_lp(internal_p), wasserstein_power); + } else { + // item was unassigned before + unassigned_items_persistence -= std::pow(items[item_idx].persistence_lp(internal_p), wasserstein_power); + } + + if (log_auction) + plot_logger->info("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9}", + num_phase, + num_rounds, + unassigned_bidders.size(), + unassigned_items_persistence, + unassigned_bidders_persistence, + unassigned_items_persistence + unassigned_bidders_persistence, + partial_cost, + total_bidders_persistence, + total_items_persistence, + oracle.get_epsilon() + ); + + + if (log_auction and unassigned_bidders.size() <= max_unassigned_to_log) { + auto logger = spdlog::get(logger_name); + if (logger) { + auto item = items[item_idx]; + auto bidder = bidders[bidder_idx]; + logger->info("{0} # ({1}, {2}) # ({3}, {4}) # {5} # {6} # {7}", + num_rounds, + item.getRealX(), + item.getRealY(), + bidder.getRealX(), + bidder.getRealY(), + format_point_set_to_log(unassigned_bidders, bidders), + format_point_set_to_log(unassigned_items, items), + oracle.get_epsilon()); + } + } +#endif +} + + +template +void AuctionRunnerGS::flush_assignment() +{ + for(auto& b2i : bidders_to_items) { + b2i = k_invalid_index; + } + for(auto& i2b : items_to_bidders) { + i2b = k_invalid_index; + } + // we must flush assignment only after we got perfect matching + assert(unassigned_bidders.empty()); + // all bidders become unassigned + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + unassigned_bidders.insert(bidder_idx); + } + assert(unassigned_bidders.size() == bidders.size()); + +#ifdef LOG_AUCTION + partial_cost = 0.0; + unassigned_bidders_persistence = total_bidders_persistence; + unassigned_items_persistence = total_items_persistence; + + for(size_t item_idx = 0; item_idx < items.size(); ++item_idx) { + unassigned_items.insert(item_idx); + } +#endif + + oracle.adjust_prices(); +} + + +template +void AuctionRunnerGS::run_auction_phases(const int max_num_phases, const Real _initial_epsilon) +{ + relative_error = std::numeric_limits::max(); + // choose some initial epsilon + oracle.set_epsilon(_initial_epsilon); + assert( oracle.get_epsilon() > 0 ); + for(int phase_num = 0; phase_num < max_num_phases; ++phase_num) { + flush_assignment(); + run_auction_phase(); + Real current_result = getDistanceToQthPowerInternal(); + Real denominator = current_result - num_bidders * oracle.get_epsilon(); + current_result = pow(current_result, 1.0 / wasserstein_power); +#ifdef LOG_AUCTION + console_logger->info("Phase {0} done, num_rounds (cumulative) = {1}, current_result = {2}, epsilon = {3}", + phase_num, format_int(num_rounds), current_result, + oracle.get_epsilon()); +#endif + if ( denominator <= 0 ) { +#ifdef LOG_AUCTION + console_logger->info("Epsilon is too large"); +#endif + } else { + denominator = pow(denominator, 1.0 / wasserstein_power); + Real numerator = current_result - denominator; + relative_error = numerator / denominator; +#ifdef LOG_AUCTION + console_logger->info("error = {0} / {1} = {2}", + numerator, denominator, relative_error); +#endif + if (relative_error <= delta) { + break; + } + } + // decrease epsilon for the next iteration + oracle.set_epsilon( oracle.get_epsilon() / epsilon_common_ratio ); + } +} + + +template +void AuctionRunnerGS::run_auction() +{ + + if (num_bidders == 1) { + assign_item_to_bidder(0, 0); + wasserstein_cost = get_item_bidder_cost(0,0); + return; + } + + double init_eps = ( initial_epsilon > 0.0 ) ? initial_epsilon : oracle.max_val_ / 4.0 ; + run_auction_phases(max_num_phases, init_eps); + is_distance_computed = true; + if (relative_error > delta) { +#ifndef FOR_R_TDA + std::cerr << "Maximum iteration number exceeded, exiting. Current result is: "; + std::cerr << pow(wasserstein_cost, 1.0/wasserstein_power) << std::endl; +#endif + throw std::runtime_error("Maximum iteration number exceeded"); + } +} + + +template +void AuctionRunnerGS::run_auction_phase() +{ + num_phase++; + //std::cout << "Entered run_auction_phase" << std::endl; + do { + size_t bidder_idx = *unassigned_bidders.begin(); + auto optimal_bid = oracle.get_optimal_bid(bidder_idx); + auto optimal_item_idx = optimal_bid.first; + auto bid_value = optimal_bid.second; + assign_item_to_bidder(optimal_bid.first, bidder_idx); + oracle.set_price(optimal_item_idx, bid_value); + //print_debug(); +#ifdef FOR_R_TDA + if ( num_rounds % 10000 == 0 ) { + Rcpp::check_user_interrupt(); + } +#endif + } while (not unassigned_bidders.empty()); + //std::cout << "run_auction_phase finished" << std::endl; + +#ifdef DEBUG_AUCTION + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if ( bidders_to_items[bidder_idx] < 0 or bidders_to_items[bidder_idx] >= num_bidders) { + std::cerr << "After auction terminated bidder " << bidder_idx; + std::cerr << " has no items assigned" << std::endl; + throw std::runtime_error("Auction did not give a perfect matching"); + } + } +#endif + +} + +template +R AuctionRunnerGS::get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx, const bool tolerate_invalid_idx) const +{ + if (item_idx != k_invalid_index and bidder_idx != k_invalid_index) { + return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p, dimension), + wasserstein_power); + } else { + if (tolerate_invalid_idx) + return R(0.0); + else + throw std::runtime_error("Invalid idx in get_item_bidder_cost, item_idx = " + std::to_string(item_idx) + ", bidder_idx = " + std::to_string(bidder_idx)); + } +} + +template +R AuctionRunnerGS::getDistanceToQthPowerInternal() +{ + sanity_check(); + Real result = 0.0; + //std::cout << "-------------------------------------------------------------------------\n"; + for(size_t bIdx = 0; bIdx < num_bidders; ++bIdx) { + result += get_item_bidder_cost(bidders_to_items[bIdx], bIdx); + } + //std::cout << "-------------------------------------------------------------------------\n"; + wasserstein_cost = result; + return result; +} + +template +R AuctionRunnerGS::get_wasserstein_distance() +{ + assert(is_distance_computed); + return pow(get_wasserstein_cost(), 1.0/wasserstein_power); +} + +template +R AuctionRunnerGS::get_wasserstein_cost() +{ + assert(is_distance_computed); + return wasserstein_cost; +} + + + +// Debug routines + +template +void AuctionRunnerGS::print_debug() +{ +#ifdef DEBUG_AUCTION + sanity_check(); + std::cout << "**********************" << std::endl; + std::cout << "Current assignment:" << std::endl; + for(size_t idx = 0; idx < bidders_to_items.size(); ++idx) { + std::cout << idx << " <--> " << bidders_to_items[idx] << std::endl; + } + std::cout << "Weights: " << std::endl; + //for(size_t i = 0; i < num_bidders; ++i) { + //for(size_t j = 0; j < num_items; ++j) { + //std::cout << oracle.weight_matrix[i][j] << " "; + //} + //std::cout << std::endl; + //} + std::cout << "Prices: " << std::endl; + for(const auto price : oracle.get_prices()) { + std::cout << price << std::endl; + } + std::cout << "**********************" << std::endl; +#endif +} + + +template +void AuctionRunnerGS::sanity_check() +{ +#ifdef DEBUG_AUCTION + if (bidders_to_items.size() != num_bidders) { + std::cerr << "Wrong size of bidders_to_items, must be " << num_bidders << ", is " << bidders_to_items.size() << std::endl; + throw std::runtime_error("Wrong size of bidders_to_items"); + } + + if (items_to_bidders.size() != num_bidders) { + std::cerr << "Wrong size of items_to_bidders, must be " << num_bidders << ", is " << items_to_bidders.size() << std::endl; + throw std::runtime_error("Wrong size of items_to_bidders"); + } + + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + assert( bidders_to_items[bidder_idx] == k_invalid_index or ( bidders_to_items[bidder_idx] < num_items and bidders_to_items[bidder_idx] >= 0)); + + if ( bidders_to_items[bidder_idx] != k_invalid_index) { + + if ( std::count(bidders_to_items.begin(), + bidders_to_items.end(), + bidders_to_items[bidder_idx]) > 1 ) { + std::cerr << "Item " << bidders_to_items[bidder_idx]; + std::cerr << " appears in bidders_to_items more than once" << std::endl; + throw std::runtime_error("Duplicate in bidders_to_items"); + } + + if (items_to_bidders.at(bidders_to_items[bidder_idx]) != static_cast(bidder_idx)) { + std::cerr << "Inconsitency: bidder_idx = " << bidder_idx; + std::cerr << ", item_idx in bidders_to_items = "; + std::cerr << bidders_to_items[bidder_idx]; + std::cerr << ", bidder_idx in items_to_bidders = "; + std::cerr << items_to_bidders[bidders_to_items[bidder_idx]] << std::endl; + throw std::runtime_error("inconsistent mapping"); + } + } + } + + for(IdxType item_idx = 0; item_idx < static_cast(num_bidders); ++item_idx) { + assert( items_to_bidders[item_idx] == k_invalid_index or ( items_to_bidders[item_idx] < num_items and items_to_bidders[item_idx] >= 0)); + if ( items_to_bidders.at(item_idx) != k_invalid_index) { + + // check for uniqueness + if ( std::count(items_to_bidders.begin(), + items_to_bidders.end(), + items_to_bidders[item_idx]) > 1 ) { + std::cerr << "Bidder " << items_to_bidders[item_idx]; + std::cerr << " appears in items_to_bidders more than once" << std::endl; + throw std::runtime_error("Duplicate in items_to_bidders"); + } + // check for consistency + if (bidders_to_items.at(items_to_bidders.at(item_idx)) != static_cast(item_idx)) { + std::cerr << "Inconsitency: item_idx = " << item_idx; + std::cerr << ", bidder_idx in items_to_bidders = "; + std::cerr << items_to_bidders[item_idx]; + std::cerr << ", item_idx in bidders_to_items= "; + std::cerr << bidders_to_items[items_to_bidders[item_idx]] << std::endl; + throw std::runtime_error("inconsistent mapping"); + } + } + } +#endif +} + +template +void AuctionRunnerGS::print_matching() +{ +#ifdef DEBUG_AUCTION + sanity_check(); + for(size_t bIdx = 0; bIdx < bidders_to_items.size(); ++bIdx) { + if (bidders_to_items[bIdx] != k_invalid_index) { + auto pA = bidders[bIdx]; + auto pB = items[bidders_to_items[bIdx]]; + std::cout << pA << " <-> " << pB << "+" << pow(dist_lp(pA, pB, internal_p, dimension), wasserstein_power) << std::endl; + } else { + assert(false); + } + } +#endif +} + +} // ws +} // hera diff --git a/src/dionysus/wasserstein/auction_runner_gs_single_diag.h b/src/dionysus/wasserstein/auction_runner_gs_single_diag.h new file mode 100755 index 0000000..f32fbbc --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_gs_single_diag.h @@ -0,0 +1,149 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef AUCTION_RUNNER_GS_SINGLE_DIAG_H +#define AUCTION_RUNNER_GS_SINGLE_DIAG_H + +#include +#include + +#include "auction_oracle.h" + +namespace hera { +namespace ws { + +// the two parameters that you can tweak in auction algorithm are: +// 1. epsilon_common_ratio +// 2. max_iter_num + +template > // alternatively: AuctionOracleLazyHeap --- TODO +class AuctionRunnerGaussSeidelSingleDiag { +public: + using RealType = RealType_; + using Real = RealType_; + using AuctionOracle = AuctionOracle_; + using DgmPoint = DiagramPoint; + using DgmPointVec = std::vector; + using DiagonalBidR = DiagonalBid; + + + + AuctionRunnerGaussSeidelSingleDiag(const DgmPointVec& A, + const DgmPointVec& B, + const Real q, + const Real _delta, + const Real _internal_p, + const Real _initial_epsilon, + const Real _eps_factor, + const int _max_iter_num = std::numeric_limits::max()); + + void set_epsilon(Real new_val) { oracle->set_epsilon(new_val); }; + Real get_epsilon() const { return oracle->get_epsilon(); } + Real get_wasserstein_cost(); + Real get_wasserstein_distance(); + Real get_relative_error() const { return relative_error; }; + void enable_logging(const char* log_filename, const size_t _max_unassigned_to_log); +//private: + // private data + DgmPointVec bidders, items; + const size_t num_bidders; + const size_t num_items; + size_t num_normal_bidders; + size_t num_diag_bidders; + size_t num_normal_items; + size_t num_diag_items; + std::vector items_to_bidders; + std::vector bidders_to_items; + const Real wasserstein_power; + const Real delta; + const Real internal_p; + const Real initial_epsilon; + Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio + const int max_iter_num; // maximal number of iterations of epsilon-scaling + Real weight_adj_const; + Real wasserstein_cost; + Real relative_error; + // to get the 2 best items we use oracle + std::unique_ptr oracle; + // unassigned guys + std::unordered_set unassigned_normal_bidders; + std::unordered_set unassigned_diag_bidders; + // private methods + // + void process_diagonal_bid(const DiagonalBidR& bid); + + void assign_item_to_bidder(const IdxType item_idx, + const IdxType bidder_idx, + const IdxType old_owner_idx, + const bool item_is_diagonal, + const bool bidder_is_diagonal, + const bool call_set_prices = false, + const Real new_price = std::numeric_limits::max()); + + void run_auction(); + void run_auction_phases(const int max_num_phases, const Real _initial_epsilon); + void run_auction_phase(); + void flush_assignment(); + // return 0, if item_idx is invalid + Real get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx, const bool tolerate_invalid_idx = false) const; + + bool is_bidder_diagonal(const size_t bidder_idx) const; + bool is_bidder_normal(const size_t bidder_idx) const; + bool is_item_diagonal(const size_t item_idx) const; + bool is_item_normal(const size_t item_idx) const; + + OwnerType get_owner_type(size_t bidder_idx) const; + + // for debug only + void sanity_check(); + void print_debug(); + int count_unhappy(); + void print_matching(); + Real getDistanceToQthPowerInternal(); + int num_phase { 0 }; + int num_rounds { 0 }; +#ifdef LOG_AUCTION + bool log_auction { false }; + std::unordered_set unassigned_items; + size_t max_unassigned_to_log { 0 }; + const char* logger_name = "auction_detailed_logger"; // the name in spdlog registry; filename is provided as parameter in enable_logging + const Real total_items_persistence; + const Real total_bidders_persistence; + Real partial_cost; + Real unassigned_bidders_persistence; + Real unassigned_items_persistence; +#endif +}; + +} // ws +} // hera + + +#include "auction_runner_gs_single_diag.hpp" + +#endif diff --git a/src/dionysus/wasserstein/auction_runner_gs_single_diag.hpp b/src/dionysus/wasserstein/auction_runner_gs_single_diag.hpp new file mode 100755 index 0000000..a3c401e --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_gs_single_diag.hpp @@ -0,0 +1,738 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + + +#include +#include +#include +#include +#include +#include + +#include "def_debug_ws.h" + +#define PRINT_DETAILED_TIMING + +#ifdef FOR_R_TDA +#include "Rcpp.h" +#undef DEBUG_AUCTION +#endif + + +namespace hera { +namespace ws { + +// ***************************** +// AuctionRunnerGaussSeidelSingleDiag +// ***************************** + +template +std::ostream& operator<<(std::ostream& s, const AuctionRunnerGaussSeidelSingleDiag& ar) +{ + s << "--------------------------------------------------\n"; + s << "AuctionRunnerGaussSeidelSingleDiag, current assignment, bidders_to_items:" << std::endl; + for(size_t idx = 0; idx < ar.bidders_to_items.size(); ++idx) { + s << idx << " <--> " << ar.bidders_to_items[idx] << std::endl; + } + s << "--------------------------------------------------\n"; + s << "AuctionRunnerGaussSeidelSingleDiag, current assignment, items_to_bidders:" << std::endl; + for(size_t idx = 0; idx < ar.items_to_bidders.size(); ++idx) { + s << idx << " <--> " << ar.items_to_bidders[idx] << std::endl; + } + s << "--------------------------------------------------\n"; + s << "AuctionRunnerGaussSeidelSingleDiag, prices:" << std::endl; + for(size_t item_idx = 0; item_idx < ar.num_items; ++item_idx) { + s << item_idx << ": " << ar.oracle->get_price(item_idx) << std::endl; + } + s << "--------------------------------------------------\n"; + s << "AuctionRunnerGaussSeidelSingleDiag, oracle :" << *(ar.oracle) << std::endl; + s << "--------------------------------------------------\n"; + return s; +} + + +template +AuctionRunnerGaussSeidelSingleDiag::AuctionRunnerGaussSeidelSingleDiag(const DgmPointVec& A, + const DgmPointVec& B, + const Real q, + const Real _delta, + const Real _internal_p, + const Real _initial_epsilon, + const Real _eps_factor, + const int _max_iter_num) : + bidders(A), + items(B), + num_bidders(A.size()), + num_items(B.size()), + items_to_bidders(B.size(), k_invalid_index), + bidders_to_items(A.size(), k_invalid_index), + wasserstein_power(q), + delta(_delta), + internal_p(_internal_p), + initial_epsilon(_initial_epsilon), + epsilon_common_ratio(_eps_factor == 0.0 ? 5.0 : _eps_factor), + max_iter_num(_max_iter_num) +#ifdef LOG_AUCTION + , total_items_persistence(std::accumulate(items.begin(), + items.end(), + R(0.0), + [_internal_p, q](const Real& ps, const DgmPoint& item) + { return ps + std::pow(item.persistence_lp(_internal_p), q); } + )) + + , total_bidders_persistence(std::accumulate(bidders.begin(), + bidders.end(), + R(0.0), + [_internal_p, q](const Real& ps, const DgmPoint& bidder) + { return ps + std::pow(bidder.persistence_lp(_internal_p), q); } + )) + , partial_cost(0.0) + , unassigned_bidders_persistence(0.0) + , unassigned_items_persistence(0.0) +#endif + +{ + assert(initial_epsilon >= 0.0 ); + assert(epsilon_common_ratio >= 0.0 ); + assert(A.size() == B.size()); + oracle = std::unique_ptr(new AuctionOracle(bidders, items, wasserstein_power, internal_p)); + + for(num_normal_bidders = 0; num_normal_bidders < num_bidders; ++num_normal_bidders) { + if (bidders[num_normal_bidders].is_diagonal()) + break; + } + + num_diag_bidders = num_bidders - num_normal_bidders; + num_diag_items = num_normal_bidders; + num_normal_items = num_items - num_diag_items; + + for(size_t i = num_normal_bidders; i < num_bidders; ++i) { + assert(bidders[i].is_diagonal()); + } + +#ifdef LOG_AUCTION + + unassigned_items_persistence = total_items_persistence; + unassigned_bidders_persistence = total_bidders_persistence; + + if (not spdlog::get("plot_logger")) { + auto log = spdlog::basic_logger_st("plot_logger", "plot_logger.txt"); + log->info("New plot starts here"); + log->set_pattern("%v"); + } +#endif + +} + +#ifdef LOG_AUCTION +template +void AuctionRunnerGaussSeidelSingleDiag::enable_logging(const char* log_filename, const size_t _max_unassigned_to_log) +{ + log_auction = true; + max_unassigned_to_log = _max_unassigned_to_log; + + auto log = spdlog::basic_logger_st(logger_name, log_filename); + log->set_pattern("%v"); +} +#endif + +template +void AuctionRunnerGaussSeidelSingleDiag::process_diagonal_bid(const DiagonalBidR& bid) +{ + + //std::cout << "Enter process_diagonal_bid, bid = " << bid << std::endl; + + // increase price of already assigned normal items + for(size_t k = 0; k < bid.assigned_normal_items.size(); ++k) { + size_t assigned_normal_item_idx = bid.assigned_normal_items[k]; + Real new_price = bid.assigned_normal_items_bid_values[k]; + bool item_is_diagonal = false; + bool bidder_is_diagonal = true; + + // TODO: SPECIAL PROCEDURE HEER` + oracle->set_price(assigned_normal_item_idx, new_price, item_is_diagonal, bidder_is_diagonal, OwnerType::k_diagonal); + } + + // set common diag-diag price + // if diag_assigned_to_diag_slice_ is empty, it will be + // numeric_limits::max() + + oracle->diag_to_diag_price_ = bid.diag_to_diag_value; + + int unassigned_diag_idx = 0; + auto unassigned_diag_item_iter = oracle->diag_unassigned_slice_.begin(); + auto bid_vec_idx = 0; + for(const auto diag_bidder_idx : unassigned_diag_bidders) { + if (unassigned_diag_idx < bid.num_from_unassigned_diag) { + // take diagonal point from unassigned slice + + //std::cout << "assigning to diag_bidder_idx = " << diag_bidder_idx << std::endl; + assert(unassigned_diag_item_iter != oracle->diag_unassigned_slice_.end()); + + auto item_idx = *unassigned_diag_item_iter; + + ++unassigned_diag_idx; + ++unassigned_diag_item_iter; + assign_item_to_bidder(item_idx, diag_bidder_idx, k_invalid_index, true, true, false); + } else { + // take point from best_item_indices + size_t item_idx = bid.best_item_indices[bid_vec_idx]; + Real new_price = bid.bid_values[bid_vec_idx]; + bid_vec_idx++; + + auto old_owner_idx = items_to_bidders[item_idx]; + bool item_is_diagonal = is_item_diagonal(item_idx); + + assign_item_to_bidder(item_idx, diag_bidder_idx, old_owner_idx, item_is_diagonal, true, true, new_price); + } + } + + // all bids of diagonal bidders are satisfied + unassigned_diag_bidders.clear(); + + if (oracle->diag_unassigned_slice_.empty()) { + oracle->diag_unassigned_price_ = std::numeric_limits::max(); + } + + //std::cout << "Exit process_diagonal_bid\n" << *this; +} + +template +bool AuctionRunnerGaussSeidelSingleDiag::is_bidder_diagonal(const size_t bidder_idx) const +{ + return bidder_idx >= num_normal_bidders; +} + +template +bool AuctionRunnerGaussSeidelSingleDiag::is_bidder_normal(const size_t bidder_idx) const +{ + return bidder_idx < num_normal_bidders; +} + +template +bool AuctionRunnerGaussSeidelSingleDiag::is_item_diagonal(const size_t item_idx) const +{ + return item_idx < num_diag_items; +} + +template +bool AuctionRunnerGaussSeidelSingleDiag::is_item_normal(const size_t item_idx) const +{ + return item_idx >= num_diag_items; +} + +template +void AuctionRunnerGaussSeidelSingleDiag::assign_item_to_bidder(const IdxType item_idx, + const IdxType bidder_idx, + const IdxType old_owner_idx, + const bool item_is_diagonal, + const bool bidder_is_diagonal, + const bool call_set_price, + const R new_price) +{ + //std::cout << "Enter assign_item_to_bidder, " << std::boolalpha ; + //std::cout << "item_idx = " << item_idx << ", bidder_idx = " << bidder_idx << ", old_owner_idx = " << old_owner_idx << ", item_is_diagonal = " << item_is_diagonal << ", bidder_is_diagonal = " << bidder_is_diagonal << std::endl; + //std::cout << "################################################################################" << std::endl; + //std::cout << *this << std::endl; + //std::cout << *(this->oracle) << std::endl; + //std::cout << "################################################################################" << std::endl; + num_rounds++; + + // for readability + const bool item_is_normal = not item_is_diagonal; + const bool bidder_is_normal = not bidder_is_diagonal; + + // only unassigned bidders should submit bids and get items + assert(bidders_to_items[bidder_idx] == k_invalid_index); + + + // update matching information + bidders_to_items[bidder_idx] = item_idx; + items_to_bidders[item_idx] = bidder_idx; + + + // remove bidder from the list of unassigned bidders + // for diagonal bidders we don't need to: in Gauss-Seidel they are all + // processed at once, so the set unassigned_diag_bidders will be cleared + if (bidder_is_normal) { + unassigned_normal_bidders.erase(bidder_idx); + } + + OwnerType old_owner_type = get_owner_type(old_owner_idx); + + if (old_owner_type != OwnerType::k_none) { + bidders_to_items[old_owner_idx] = k_invalid_index; + } + + switch(old_owner_type) + { + case OwnerType::k_normal : unassigned_normal_bidders.insert(old_owner_idx); + break; + case OwnerType::k_diagonal : unassigned_diag_bidders.insert(old_owner_idx); + break; + case OwnerType::k_none : break; + } + + + // update normal_items_assigned_to_diag_ + + if (old_owner_type == OwnerType::k_diagonal and item_is_normal and bidder_is_normal) { + // normal item was stolen from diagonal, erase + assert( oracle->normal_items_assigned_to_diag_.count(item_idx) == 1 ); + oracle->normal_items_assigned_to_diag_.erase(item_idx); + } else if (bidder_is_diagonal and item_is_normal and old_owner_type != OwnerType::k_diagonal) { + // diagonal bidder got a new normal item, insert + assert(oracle->normal_items_assigned_to_diag_.count(item_idx) == 0); + oracle->normal_items_assigned_to_diag_.insert(item_idx); + } + + + // update diag_assigned_to_diag_slice_ + if (item_is_diagonal and bidder_is_normal and old_owner_type == OwnerType::k_diagonal) { + assert( oracle->diag_assigned_to_diag_slice_.count(item_idx) == 1); + oracle->diag_assigned_to_diag_slice_.erase(item_idx); + } else if (item_is_diagonal and bidder_is_diagonal) { + assert( old_owner_type != OwnerType::k_diagonal ); // diagonal does not steal from itself + assert( oracle->diag_assigned_to_diag_slice_.count(item_idx) == 0); + oracle->diag_assigned_to_diag_slice_.insert(item_idx); + } + + // update diag_unassigned_slice_ + if (item_is_diagonal and old_owner_type == OwnerType::k_none) { + oracle->diag_unassigned_slice_.erase(item_idx); + } + + if ( not (not call_set_price or new_price != std::numeric_limits::max())) { + std::cout << "In the middle of assign_item_to_bidder, " << std::boolalpha ; + std::cout << "item_idx = " << item_idx << ", bidder_idx = " << bidder_idx << ", old_owner_idx = " << old_owner_idx << ", item_is_diagonal = " << item_is_diagonal << ", bidder_is_diagonal = " << bidder_is_diagonal << std::endl; + std::cout << "################################################################################" << std::endl; + std::cout << *this << std::endl; + std::cout << "################################################################################" << std::endl; + } + assert(not call_set_price or new_price != std::numeric_limits::max()); + if (call_set_price) { + oracle->set_price(item_idx, new_price, item_is_diagonal, bidder_is_diagonal, old_owner_type); + } + + //std::cout << "Exit assign_item_to_bidder, state\n" << *this << std::endl; + +#ifdef LOG_AUCTION + + partial_cost += get_item_bidder_cost(item_idx, bidder_idx, true); + partial_cost -= get_item_bidder_cost(item_idx, old_owner_idx, true); + + unassigned_items.erase(item_idx); + + unassigned_bidders_persistence -= std::pow(bidders[bidder_idx].persistence_lp(internal_p), wasserstein_power); + + if (old_owner_type != OwnerType::k_none) { + // item has been assigned to some other bidder, + // and he became unassigned + unassigned_bidders_persistence += std::pow(bidders[old_owner_idx].persistence_lp(internal_p), wasserstein_power); + } else { + // item was unassigned before + unassigned_items_persistence -= std::pow(items[item_idx].persistence_lp(internal_p), wasserstein_power); + } + + auto plot_logger = spdlog::get("plot_logger"); + plot_logger->info("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9} {10}", + num_phase, + num_rounds, + unassigned_normal_bidders.size(), + unassigned_diag_bidders.size(), + unassigned_items_persistence, + unassigned_bidders_persistence, + unassigned_items_persistence + unassigned_bidders_persistence, + partial_cost, + total_bidders_persistence, + total_items_persistence, + oracle->get_epsilon() + ); + + + if (log_auction and unassigned_normal_bidders.size() + unassigned_diag_bidders.size() <= max_unassigned_to_log) { + auto logger = spdlog::get(logger_name); + if (logger) { + auto item = items[item_idx]; + auto bidder = bidders[bidder_idx]; + logger->info("{0} # ({1}, {2}) # ({3}, {4}) # {5} # {6} # {7} # {8}", + num_rounds, + item.getRealX(), + item.getRealY(), + bidder.getRealX(), + bidder.getRealY(), + format_point_set_to_log(unassigned_diag_bidders, bidders), + format_point_set_to_log(unassigned_normal_bidders, bidders), + format_point_set_to_log(unassigned_items, items), + oracle->get_epsilon()); + } + } +#endif +} + + + +template +void AuctionRunnerGaussSeidelSingleDiag::flush_assignment() +{ + for(auto& b2i : bidders_to_items) { + b2i = k_invalid_index; + } + for(auto& i2b : items_to_bidders) { + i2b = k_invalid_index; + } + + // we must flush assignment only after we got perfect matching + assert(unassigned_normal_bidders.empty() and unassigned_diag_bidders.empty()); + // all bidders become unassigned + for(size_t bidder_idx = 0; bidder_idx < num_normal_bidders; ++bidder_idx) { + unassigned_normal_bidders.insert(bidder_idx); + } + for(size_t bidder_idx = num_normal_bidders; bidder_idx < num_bidders; ++bidder_idx) { + unassigned_diag_bidders.insert(bidder_idx); + } + assert(unassigned_normal_bidders.size() + unassigned_diag_bidders.size() == bidders.size()); + assert(unassigned_normal_bidders.size() == num_normal_bidders); + assert(unassigned_diag_bidders.size() == num_diag_bidders); + + oracle->flush_assignment(); + oracle->adjust_prices(); + +#ifdef LOG_AUCTION + partial_cost = 0.0; + unassigned_bidders_persistence = total_bidders_persistence; + unassigned_items_persistence = total_items_persistence; + + for(size_t item_idx = 0; item_idx < items.size(); ++item_idx) { + unassigned_items.insert(item_idx); + } +#endif + +} + + +template +void AuctionRunnerGaussSeidelSingleDiag::run_auction_phases(const int max_num_phases, const Real _initial_epsilon) +{ + relative_error = std::numeric_limits::max(); + // choose some initial epsilon + oracle->set_epsilon(_initial_epsilon); + assert( oracle->get_epsilon() > 0 ); + for(int phase_num = 0; phase_num < max_num_phases; ++phase_num) { + flush_assignment(); + run_auction_phase(); + phase_num++; + //std::cout << "Iteration " << phase_num << " completed. " << std::endl; + // result is d^q + Real current_result = getDistanceToQthPowerInternal(); + Real denominator = current_result - num_bidders * oracle->get_epsilon(); + current_result = pow(current_result, 1.0 / wasserstein_power); + //std::cout << "Current result is " << current_result << std::endl; + if ( denominator <= 0 ) { + //std::cout << "Epsilon is too big." << std::endl; + } else { + denominator = pow(denominator, 1.0 / wasserstein_power); + Real numerator = current_result - denominator; + relative_error = numerator / denominator; + //std::cout << " numerator: " << numerator << " denominator: " << denominator << std::endl; + //std::cout << " error bound: " << numerator / denominator << std::endl; + // if relative error is greater than delta, continue + if (relative_error <= delta) { + break; + } + } + // decrease epsilon for the next iteration + oracle->set_epsilon( oracle->get_epsilon() / epsilon_common_ratio ); + } + //print_matching(); +} + + +template +void AuctionRunnerGaussSeidelSingleDiag::run_auction() +{ + double init_eps = ( initial_epsilon > 0.0 ) ? initial_epsilon : oracle->max_val_ / 4.0 ; + run_auction_phases(max_iter_num, init_eps); + if (relative_error > delta) { +#ifndef FOR_R_TDA + std::cerr << "Maximum iteration number exceeded, exiting. Current result is: "; + std::cerr << pow(wasserstein_cost, 1.0/wasserstein_power) << std::endl; +#endif + throw std::runtime_error("Maximum iteration number exceeded"); + } +} + +template +OwnerType AuctionRunnerGaussSeidelSingleDiag::get_owner_type(size_t bidder_idx) const +{ + if (bidder_idx == k_invalid_index) { + return OwnerType::k_none; + } else if (is_bidder_diagonal(bidder_idx)) { + return OwnerType::k_diagonal; + } else { + return OwnerType::k_normal; + } +} + +template +void AuctionRunnerGaussSeidelSingleDiag::run_auction_phase() +{ + num_phase++; + //std::cout << "Entered run_auction_phase" << std::endl; + do { + + if (not unassigned_diag_bidders.empty()) { + // process all unassigned diagonal bidders + // easy for Gauss-Seidel: every bidder alwasy gets all he wants + // + sanity_check(); + //std::cout << "Current state " << __LINE__ << *this << std::endl; + process_diagonal_bid(oracle->get_optimal_bids_for_diagonal( unassigned_diag_bidders.size() )); + sanity_check(); + } else { + sanity_check(); + // process normal unassigned bidder + size_t bidder_idx = *(unassigned_normal_bidders.begin()); + auto optimal_bid = oracle->get_optimal_bid(bidder_idx); + auto optimal_item_idx = optimal_bid.first; + auto bid_value = optimal_bid.second; + bool item_is_diagonal = is_item_diagonal(optimal_item_idx); + size_t old_owner_idx = items_to_bidders[optimal_item_idx]; + + //OwnerType old_owner_type = get_owner_type(old_owner_idx); + //std::cout << "bidder_idx = " << bidder_idx << ", item_idx = " << optimal_item_idx << ", old_owner_type = " << old_owner_type << std::endl; + + assign_item_to_bidder(optimal_item_idx, bidder_idx, old_owner_idx, item_is_diagonal, false, true, bid_value); + sanity_check(); + } + +#ifdef FOR_R_TDA + if ( num_rounds % 10000 == 0 ) { + Rcpp::check_user_interrupt(); + } +#endif + } while (not (unassigned_diag_bidders.empty() and unassigned_normal_bidders.empty())); + //std::cout << "run_auction_phase finished" << std::endl; + +#ifdef DEBUG_AUCTION + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if ( bidders_to_items[bidder_idx] < 0 or bidders_to_items[bidder_idx] >= (IdxType)num_bidders) { + std::cerr << "After auction terminated bidder " << bidder_idx; + std::cerr << " has no items assigned" << std::endl; + throw std::runtime_error("Auction did not give a perfect matching"); + } + } +#endif + +} + +template +R AuctionRunnerGaussSeidelSingleDiag::get_item_bidder_cost(size_t item_idx, size_t bidder_idx, const bool tolerate_invalid_idx) const +{ + if (item_idx != k_invalid_index and bidder_idx != k_invalid_index) { + // skew edges are replaced by edges to projection + if (is_bidder_diagonal(bidder_idx) and is_item_normal(item_idx)) { + bidder_idx = item_idx; + } else if (is_bidder_normal(bidder_idx) and is_item_diagonal(item_idx)) { + item_idx = bidder_idx; + } + return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p), + wasserstein_power); + } else { + if (tolerate_invalid_idx) + return R(0.0); + else + throw std::runtime_error("Invalid idx in get_item_bidder_cost, item_idx = " + std::to_string(item_idx) + ", bidder_idx = " + std::to_string(bidder_idx)); + } +} + +template +R AuctionRunnerGaussSeidelSingleDiag::getDistanceToQthPowerInternal() +{ + sanity_check(); + Real result = 0.0; + for(size_t bIdx = 0; bIdx < num_bidders; ++bIdx) { + result += get_item_bidder_cost(bidders_to_items[bIdx], bIdx); + } + wasserstein_cost = result; + return result; +} + +template +R AuctionRunnerGaussSeidelSingleDiag::get_wasserstein_distance() +{ + return pow(get_wasserstein_cost(), 1.0/wasserstein_power); +} + +template +R AuctionRunnerGaussSeidelSingleDiag::get_wasserstein_cost() +{ + run_auction(); + return wasserstein_cost; +} + + + +// Debug routines + + +template +void AuctionRunnerGaussSeidelSingleDiag::print_debug() +{ +#ifdef DEBUG_AUCTION + std::cout << "**********************" << std::endl; + std::cout << "Current assignment:" << std::endl; + for(size_t idx = 0; idx < bidders_to_items.size(); ++idx) { + std::cout << idx << " <--> " << bidders_to_items[idx] << std::endl; + } + std::cout << "Weights: " << std::endl; + //for(size_t i = 0; i < num_bidders; ++i) { + //for(size_t j = 0; j < num_items; ++j) { + //std::cout << oracle->weight_matrix[i][j] << " "; + //} + //std::cout << std::endl; + //} + std::cout << "Prices: " << std::endl; + for(const auto price : oracle->get_prices()) { + std::cout << price << std::endl; + } + std::cout << "**********************" << std::endl; +#endif +} + + +template +void AuctionRunnerGaussSeidelSingleDiag::sanity_check() +{ +#ifdef DEBUG_AUCTION + if (bidders_to_items.size() != num_bidders) { + std::cerr << "Wrong size of bidders_to_items, must be " << num_bidders << ", is " << bidders_to_items.size() << std::endl; + throw std::runtime_error("Wrong size of bidders_to_items"); + } + + if (items_to_bidders.size() != num_bidders) { + std::cerr << "Wrong size of items_to_bidders, must be " << num_bidders << ", is " << items_to_bidders.size() << std::endl; + throw std::runtime_error("Wrong size of items_to_bidders"); + } + + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + assert( bidders_to_items[bidder_idx] == k_invalid_index or ( bidders_to_items[bidder_idx] < static_cast(num_items) and bidders_to_items[bidder_idx] >= 0)); + + if ( bidders_to_items[bidder_idx] != k_invalid_index) { + + if ( std::count(bidders_to_items.begin(), + bidders_to_items.end(), + bidders_to_items[bidder_idx]) > 1 ) { + std::cerr << "Item " << bidders_to_items[bidder_idx]; + std::cerr << " appears in bidders_to_items more than once" << std::endl; + throw std::runtime_error("Duplicate in bidders_to_items"); + } + + if (items_to_bidders.at(bidders_to_items[bidder_idx]) != static_cast(bidder_idx)) { + std::cerr << "Inconsitency: bidder_idx = " << bidder_idx; + std::cerr << ", item_idx in bidders_to_items = "; + std::cerr << bidders_to_items[bidder_idx]; + std::cerr << ", bidder_idx in items_to_bidders = "; + std::cerr << items_to_bidders[bidders_to_items[bidder_idx]] << std::endl; + throw std::runtime_error("inconsistent mapping"); + } + } + } + + for(size_t item_idx = 0; item_idx < num_diag_items; ++item_idx) { + auto owner = items_to_bidders.at(item_idx); + if ( owner == k_invalid_index) { + assert((oracle->diag_unassigned_slice_.count(item_idx) == 1 and + oracle->diag_items_heap__iters_[item_idx] == oracle->diag_items_heap_.end() and + oracle->all_items_heap__iters_[item_idx] == oracle->all_items_heap_.end()) + or + (oracle->diag_unassigned_slice_.count(item_idx) == 0 and + oracle->diag_items_heap__iters_[item_idx] != oracle->diag_items_heap_.end() and + oracle->all_items_heap__iters_[item_idx] != oracle->all_items_heap_.end())); + assert(oracle->diag_assigned_to_diag_slice_.count(item_idx) == 0); + } else { + if (is_bidder_diagonal(owner)) { + assert(oracle->diag_unassigned_slice_.count(item_idx) == 0); + assert(oracle->diag_assigned_to_diag_slice_.count(item_idx) == 1); + assert(oracle->diag_items_heap__iters_[item_idx] == oracle->diag_items_heap_.end()); + assert(oracle->all_items_heap__iters_[item_idx] == oracle->all_items_heap_.end()); + } else { + assert(oracle->diag_unassigned_slice_.count(item_idx) == 0); + assert(oracle->diag_assigned_to_diag_slice_.count(item_idx) == 0); + assert(oracle->diag_items_heap__iters_[item_idx] != oracle->diag_items_heap_.end()); + assert(oracle->all_items_heap__iters_[item_idx] != oracle->all_items_heap_.end()); + } + } + } + + for(IdxType item_idx = 0; item_idx < static_cast(num_bidders); ++item_idx) { + assert( items_to_bidders[item_idx] == k_invalid_index or ( items_to_bidders[item_idx] < static_cast(num_items) and items_to_bidders[item_idx] >= 0)); + if ( items_to_bidders.at(item_idx) != k_invalid_index) { + + // check for uniqueness + if ( std::count(items_to_bidders.begin(), + items_to_bidders.end(), + items_to_bidders[item_idx]) > 1 ) { + std::cerr << "Bidder " << items_to_bidders[item_idx]; + std::cerr << " appears in items_to_bidders more than once" << std::endl; + throw std::runtime_error("Duplicate in items_to_bidders"); + } + // check for consistency + if (bidders_to_items.at(items_to_bidders.at(item_idx)) != static_cast(item_idx)) { + std::cerr << "Inconsitency: item_idx = " << item_idx; + std::cerr << ", bidder_idx in items_to_bidders = "; + std::cerr << items_to_bidders[item_idx]; + std::cerr << ", item_idx in bidders_to_items= "; + std::cerr << bidders_to_items[items_to_bidders[item_idx]] << std::endl; + throw std::runtime_error("inconsistent mapping"); + } + } + } + + oracle->sanity_check(); +#endif +} + +template +void AuctionRunnerGaussSeidelSingleDiag::print_matching() +{ +#ifdef DEBUG_AUCTION + sanity_check(); + for(size_t bIdx = 0; bIdx < bidders_to_items.size(); ++bIdx) { + if (bidders_to_items[bIdx] != k_invalid_index) { + auto pA = bidders[bIdx]; + auto pB = items[bidders_to_items[bIdx]]; + std::cout << pA << " <-> " << pB << "+" << pow(dist_lp(pA, pB, internal_p), wasserstein_power) << std::endl; + } else { + assert(false); + } + } +#endif +} + +} // ws +} // hera diff --git a/src/dionysus/wasserstein/auction_runner_jac.h b/src/dionysus/wasserstein/auction_runner_jac.h new file mode 100755 index 0000000..252ca32 --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_jac.h @@ -0,0 +1,230 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef HERA_AUCTION_RUNNER_JAC_H +#define HERA_AUCTION_RUNNER_JAC_H + +#ifdef WASSERSTEIN_PURE_GEOM +#undef LOG_AUCTION +#undef ORDERED_BY_PERSISTENCE +#endif + +//#define ORDERED_BY_PERSISTENCE + +#include + +#include "auction_oracle.h" + +namespace hera { +namespace ws { + +// the two parameters that you can tweak in auction algorithm are: +// 1. epsilon_common_ratio +// 2. max_num_phases + +template, class PointContainer_ = std::vector> > // alternatively: AuctionOracleLazyHeap --- TODO +class AuctionRunnerJac { +public: + + using Real = RealType_; + using AuctionOracle = AuctionOracle_; + using DgmPoint = typename AuctionOracle::DiagramPointR; + using IdxValPairR = IdxValPair; + using PointContainer = PointContainer_; + + const Real k_lowest_bid_value = -1; // all bid values must be positive + + + AuctionRunnerJac(const PointContainer& A, + const PointContainer& B, + const AuctionParams& params, + const std::string& _log_filename_prefix = ""); + + void set_epsilon(Real new_val); + Real get_epsilon() const { return epsilon; } + void run_auction(); + template + void run_bidding_step(const Range& r); + bool is_done() const; + void decrease_epsilon(); + Real get_wasserstein_distance(); + Real get_wasserstein_cost(); + Real get_relative_error(const bool debug_output = false) const; +//private: + // private data + PointContainer bidders; + PointContainer items; + const size_t num_bidders; + const size_t num_items; + std::vector items_to_bidders; + std::vector bidders_to_items; + Real wasserstein_power; + Real epsilon; + Real delta; + Real internal_p; + Real initial_epsilon; + const Real epsilon_common_ratio; // next epsilon = current epsilon / epsilon_common_ratio + const int max_num_phases; // maximal number of phases of epsilon-scaling + Real weight_adj_const; + Real wasserstein_cost; + std::vector bid_table; + // to get the 2 best items + AuctionOracle oracle; + std::unordered_set unassigned_bidders; + std::unordered_set items_with_bids; + // to imitate Gauss-Seidel + const size_t max_bids_per_round; + Real partial_cost { 0.0 }; + bool is_distance_computed { false }; + int num_rounds { 0 }; + int num_phase { 0 }; + int dimension; + + size_t unassigned_threshold; // for experiments + +#ifndef WASSERSTEIN_PURE_GEOM + std::unordered_set unassigned_normal_bidders; + std::unordered_set unassigned_diag_bidders; + bool diag_first {true}; + size_t batch_size { 1000 }; +#ifdef ORDERED_BY_PERSISTENCE + // to process unassigned by persistence + using RealIdxPair = std::pair; + std::set> unassigned_normal_bidders_by_persistence; +#endif + + // to stop earlier in the last phase + const Real total_items_persistence; + const Real total_bidders_persistence; + Real unassigned_bidders_persistence; + Real unassigned_items_persistence; + Real gamma_threshold; + + + size_t num_diag_items { 0 }; + size_t num_normal_items { 0 }; + size_t num_diag_bidders { 0 }; + size_t num_normal_bidders { 0 }; + + +#endif + + + + // private methods + void assign_item_to_bidder(const IdxType bidder_idx, const IdxType items_idx); + void assign_to_best_bidder(const IdxType items_idx); + void clear_bid_table(); + void run_auction_phases(const int max_num_phases, const Real _initial_epsilon); + void run_auction_phase(); + void submit_bid(IdxType bidder_idx, const IdxValPairR& items_bid_value_pair); + void flush_assignment(); + Real get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx) const; +#ifndef WASSERSTEIN_PURE_GEOM + Real get_cost_to_diagonal(const DgmPoint& pt) const; + Real get_gamma() const; +#endif + bool continue_auction_phase() const; + + void add_unassigned_bidder(const size_t bidder_idx); + void remove_unassigned_bidder(const size_t bidder_idx); + void remove_unassigned_item(const size_t item_idx); + +#ifndef WASSERSTEIN_PURE_GEOM + bool is_item_diagonal(const size_t item_idx) const { return item_idx < num_diag_items; } + bool is_item_normal(const size_t item_idx) const { return not is_item_diagonal(item_idx); } + bool is_bidder_diagonal(const size_t bidder_idx) const { return bidder_idx >= num_normal_bidders; } + bool is_bidder_normal(const size_t bidder_idx) const { return not is_bidder_diagonal(bidder_idx); } +#endif + + + + // for debug only + void sanity_check(); + void print_debug(); + void print_matching(); + + std::string log_filename_prefix; + const Real k_max_relative_error = 2.0; // if relative error cannot be estimated or is too large, use this value + +#ifdef LOG_AUCTION + + size_t parallel_threshold { 5000 }; + bool is_step_parallel {false}; + std::unordered_set unassigned_items; + std::unordered_set unassigned_normal_items; + std::unordered_set unassigned_diag_items; + std::unordered_set never_assigned_bidders; + size_t all_assigned_round { 0 }; + size_t all_assigned_round_found { false }; + + int num_rounds_non_cumulative { 0 }; // set to 0 in the beginning of each phase + int num_diag_assignments { 0 }; + int num_diag_assignments_non_cumulative { 0 }; + int num_diag_bids_submitted { 0 }; + int num_diag_stole_from_diag { 0 }; + int num_normal_assignments { 0 }; + int num_normal_assignments_non_cumulative { 0 }; + int num_normal_bids_submitted { 0 }; + + std::vector> price_change_cnt_vec; + + + const char* plot_logger_name = "plot_logger"; + const char* price_state_logger_name = "price_stat_logger"; + std::string plot_logger_file_name; + std::string price_stat_logger_file_name; + std::shared_ptr plot_logger; + std::shared_ptr price_stat_logger; + std::shared_ptr console_logger; + + + int num_parallel_bids { 0 }; + int num_total_bids { 0 }; + + int num_parallel_diag_bids { 0 }; + int num_total_diag_bids { 0 }; + + int num_parallel_normal_bids { 0 }; + int num_total_normal_bids { 0 }; + + int num_parallel_assignments { 0 }; + int num_total_assignments { 0 }; +#endif + +}; // AuctionRunnerJac + + +} // ws +} // hera + +#include "auction_runner_jac.hpp" + +#undef ORDERED_BY_PERSISTENCE + +#endif diff --git a/src/dionysus/wasserstein/auction_runner_jac.hpp b/src/dionysus/wasserstein/auction_runner_jac.hpp new file mode 100755 index 0000000..8663bae --- /dev/null +++ b/src/dionysus/wasserstein/auction_runner_jac.hpp @@ -0,0 +1,873 @@ +/* + +Copyright (c) 2016, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef AUCTION_RUNNER_JAC_HPP +#define AUCTION_RUNNER_JAC_HPP + +#include +#include +#include +#include + +#include "def_debug_ws.h" +#include "auction_runner_jac.h" + + +#ifdef FOR_R_TDA +#include "Rcpp.h" +#undef DEBUG_AUCTION +#endif + + +namespace hera { +namespace ws { + + +// ***************************** +// AuctionRunnerJac +// ***************************** + + template + AuctionRunnerJac::AuctionRunnerJac(const PointContainer& A, + const PointContainer& B, + const AuctionParams& params, + const std::string &_log_filename_prefix + ) : + bidders(A), + items(B), + num_bidders(A.size()), + num_items(A.size()), + items_to_bidders(A.size(), k_invalid_index), + bidders_to_items(A.size(), k_invalid_index), + wasserstein_power(params.wasserstein_power), + delta(params.delta), + internal_p(params.internal_p), + initial_epsilon(params.initial_epsilon), + epsilon_common_ratio(params.epsilon_common_ratio == 0.0 ? 5.0 : params.epsilon_common_ratio), + max_num_phases(params.max_num_phases), + bid_table(A.size(), std::make_pair(k_invalid_index, k_lowest_bid_value)), + oracle(bidders, items, params), + max_bids_per_round(params.max_bids_per_round), + dimension(params.dim), +#ifndef WASSERSTEIN_PURE_GEOM + total_items_persistence(std::accumulate(items.begin(), + items.end(), + R(0.0), + [params](const Real &ps, const DgmPoint &item) { + return ps + std::pow(item.persistence_lp(params.internal_p), params.wasserstein_power); + } + )), + total_bidders_persistence(std::accumulate(bidders.begin(), + bidders.end(), + R(0.0), + [params](const Real &ps, const DgmPoint &bidder) { + return ps + std::pow(bidder.persistence_lp(params.internal_p), params.wasserstein_power); + } + )), + unassigned_bidders_persistence(total_bidders_persistence), + unassigned_items_persistence(total_items_persistence), + gamma_threshold(params.gamma_threshold), +#endif + log_filename_prefix(_log_filename_prefix) + { + assert(A.size() == B.size()); + +#ifndef WASSERSTEIN_PURE_GEOM + for (const auto &p : bidders) { + if (p.is_normal()) { + num_normal_bidders++; + num_diag_items++; + } else { + num_normal_items++; + num_diag_bidders++; + } + } +#endif + // for experiments + unassigned_threshold = 100; + +#ifdef ORDERED_BY_PERSISTENCE + batch_size = 1000; + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if (is_bidder_normal(bidder_idx)) { + unassigned_normal_bidders_by_persistence.insert( + std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)); + } + } +#endif + +#ifdef LOG_AUCTION + parallel_threshold = 16; + console_logger = spdlog::get("console"); + if (not console_logger) { + console_logger = spdlog::stdout_logger_st("console"); + } + console_logger->set_pattern("[%H:%M:%S.%e] %v"); +#ifdef ORDERED_BY_PERSISTENCE + if (max_bids_per_round == 1) { + console_logger->info("Gauss-Seidel imitated by Jacobi runner, q = {0}, max_bids_per_round = {1}, batch_size = {4}, gamma_threshold = {2}, diag_first = {3} ORDERED_BY_PERSISTENCE", + wasserstein_power, + max_bids_per_round, + gamma_threshold, + diag_first, + batch_size); + } else { + console_logger->info("Jacobi runner, q = {0}, max_bids_per_round = {1}, batch_size = {4}, gamma_threshold = {2}, diag_first = {3} ORDERED_BY_PERSISTENCE", + wasserstein_power, + max_bids_per_round, + gamma_threshold, + diag_first, + batch_size); + } + +#else + if (max_bids_per_round == 1) { + console_logger->info( + "Gauss-Seidel imitated by Jacobi runner, q = {0}, max_bids_per_round = {1}, batch_size = {4}, gamma_threshold = {2}, diag_first = {3}", + wasserstein_power, + max_bids_per_round, + gamma_threshold, + diag_first, + batch_size); + } else { + console_logger->info( + "Jacobi runner, q = {0}, max_bids_per_round = {1}, batch_size = {4}, gamma_threshold = {2}, diag_first = {3}", + wasserstein_power, + max_bids_per_round, + gamma_threshold, + diag_first, + batch_size); + } +#endif + + plot_logger_file_name = log_filename_prefix + "_plot.txt"; + plot_logger = spdlog::get(plot_logger_name); + if (not plot_logger) { + plot_logger = spdlog::basic_logger_st(plot_logger_name, plot_logger_file_name); + } + plot_logger->info("New plot starts here, diagram size = {0}, gamma_threshold = {1}, epsilon_common_ratio = {2}", + bidders.size(), + gamma_threshold, + epsilon_common_ratio); + plot_logger->set_pattern("%v"); + + price_stat_logger_file_name = log_filename_prefix + "_price_change_stat"; + price_stat_logger = spdlog::get(price_state_logger_name); + if (not price_stat_logger) { + price_stat_logger = spdlog::basic_logger_st(price_state_logger_name, + price_stat_logger_file_name); + } + price_stat_logger->info( + "New price statistics starts here, diagram size = {0}, gamma_threshold = {1}, epsilon_common_ratio = {2}", + bidders.size(), + gamma_threshold, + epsilon_common_ratio); + price_stat_logger->set_pattern("%v"); +#endif + } + +#ifndef WASSERSTEIN_PURE_GEOM + template + typename AuctionRunnerJac::Real + AuctionRunnerJac::get_cost_to_diagonal(const DgmPoint &pt) const { + return std::pow(pt.persistence_lp(internal_p), wasserstein_power); + } + + template + typename AuctionRunnerJac::Real + AuctionRunnerJac::get_gamma() const { + return std::pow(std::fabs(unassigned_items_persistence + unassigned_bidders_persistence), + 1.0 / wasserstein_power); + } +#endif + + template + void AuctionRunnerJac::assign_item_to_bidder(IdxType item_idx, IdxType bidder_idx) + { + //sanity_check(); + // only unassigned bidders submit bids + assert(bidders_to_items[bidder_idx] == k_invalid_index); + + IdxType old_item_owner = items_to_bidders[item_idx]; + + // set new owner + bidders_to_items[bidder_idx] = item_idx; + items_to_bidders[item_idx] = bidder_idx; + + // remove bidder and item from the sets of unassigned bidders/items + remove_unassigned_bidder(bidder_idx); + + if (k_invalid_index != old_item_owner) { + // old owner of item becomes unassigned + bidders_to_items[old_item_owner] = k_invalid_index; + add_unassigned_bidder(old_item_owner); + // existing edge was removed, decrease partial_cost + partial_cost -= get_item_bidder_cost(item_idx, old_item_owner); + } else { + // item was unassigned before + remove_unassigned_item(item_idx); + } + + // new edge was added to matching, increase partial cost + partial_cost += get_item_bidder_cost(item_idx, bidder_idx); + +#ifdef LOG_AUCTION + if (is_item_diagonal(item_idx)) { + num_diag_assignments++; + num_diag_assignments_non_cumulative++; + } else { + num_normal_assignments++; + num_normal_assignments_non_cumulative++; + } + + if (k_invalid_index != old_item_owner) { + if (is_bidder_diagonal(bidder_idx) and is_bidder_diagonal(old_item_owner)) { + num_diag_stole_from_diag++; + } + } +#endif + + //sanity_check(); + } + + template + typename AuctionRunnerJac::Real + AuctionRunnerJac::get_item_bidder_cost(const size_t item_idx, const size_t bidder_idx) const + { + return std::pow(dist_lp(bidders[bidder_idx], items[item_idx], internal_p, dimension), + wasserstein_power); + } + + template + void AuctionRunnerJac::assign_to_best_bidder(IdxType item_idx) { + assert(item_idx >= 0 and item_idx < static_cast(num_items)); + assert(bid_table[item_idx].first != k_invalid_index); + IdxValPairR best_bid{bid_table[item_idx]}; + assign_item_to_bidder(item_idx, best_bid.first); + oracle.set_price(item_idx, best_bid.second); +#ifdef LOG_AUCTION + + if (is_step_parallel) { + num_parallel_assignments++; + } + num_total_assignments++; + + price_change_cnt_vec.back()[item_idx]++; +#endif + } + + template + void AuctionRunnerJac::clear_bid_table() { + auto iter = items_with_bids.begin(); + while (iter != items_with_bids.end()) { + auto item_with_bid_idx = *iter; + bid_table[item_with_bid_idx].first = k_invalid_index; + bid_table[item_with_bid_idx].second = k_lowest_bid_value; + iter = items_with_bids.erase(iter); + } + } + + template + void AuctionRunnerJac::submit_bid(IdxType bidder_idx, const IdxValPairR &bid) { + IdxType item_idx = bid.first; + Real bid_value = bid.second; + assert(item_idx >= 0); + if (bid_table[item_idx].second < bid_value) { + bid_table[item_idx].first = bidder_idx; + bid_table[item_idx].second = bid_value; + } + items_with_bids.insert(item_idx); + +#ifdef LOG_AUCTION + + num_total_bids++; + + + if (is_bidder_diagonal(bidder_idx)) { + num_diag_bids_submitted++; + } else { + num_normal_bids_submitted++; + } +#endif + } + + template + void AuctionRunnerJac::print_debug() { +#ifdef DEBUG_AUCTION + sanity_check(); + std::cout << "**********************" << std::endl; + std::cout << "Current assignment:" << std::endl; + for(size_t idx = 0; idx < bidders_to_items.size(); ++idx) { + std::cout << idx << " <--> " << bidders_to_items[idx] << std::endl; + } + std::cout << "Weights: " << std::endl; + //for(size_t i = 0; i < num_bidders; ++i) { + //for(size_t j = 0; j < num_items; ++j) { + //std::cout << oracle.weight_matrix[i][j] << " "; + //} + //std::cout << std::endl; + //} + std::cout << "Prices: " << std::endl; + for(const auto price : oracle.get_prices()) { + std::cout << price << std::endl; + } + //std::cout << "Value matrix: " << std::endl; + //for(size_t i = 0; i < num_bidders; ++i) { + //for(size_t j = 0; j < num_items; ++j) { + //std::cout << oracle.weight_matrix[i][j] - oracle.prices[j] << " "; + //} + //std::cout << std::endl; + //} + std::cout << "**********************" << std::endl; +#endif + } + + template + typename AuctionRunnerJac::Real + AuctionRunnerJac::get_relative_error(const bool debug_output) const + { + Real result; +#ifndef WASSERSTEIN_PURE_GEOM + Real gamma = get_gamma(); +#else + Real gamma = 0.0; +#endif + // cost minus n epsilon + Real reduced_cost = partial_cost - num_bidders * get_epsilon(); + if (reduced_cost < 0) { +#ifdef LOG_AUCTION + if (debug_output) { + console_logger->info("Epsilon too large, reduced_cost = {0}, gamma = {1}", reduced_cost, gamma); + } +#endif + result = k_max_relative_error; + } else { + Real denominator = std::pow(reduced_cost, 1.0 / wasserstein_power) - gamma; + if (denominator <= 0) { +#ifdef LOG_AUCTION + if (debug_output) { + console_logger->info("Epsilon too large, reduced_cost = {0}, denominator = {1}, gamma = {2}", + reduced_cost, denominator, gamma); + } +#endif + result = k_max_relative_error; + } else { + Real numerator = 2 * gamma + + std::pow(partial_cost, 1.0 / wasserstein_power) - + std::pow(reduced_cost, 1.0 / wasserstein_power); + + result = numerator / denominator; +#ifdef LOG_AUCTION + if (debug_output) { + console_logger->info( + "Reduced_cost = {0}, denominator = {1}, numerator {2}, error = {3}, gamma = {4}", + reduced_cost, + denominator, + numerator, + result, + gamma); + } +#endif + } + } + return result; + } + + template + void AuctionRunnerJac::flush_assignment() { + for (auto &b2i : bidders_to_items) { + b2i = k_invalid_index; + } + for (auto &i2b : items_to_bidders) { + i2b = k_invalid_index; + } + + // all bidders and items become unassigned + for (size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + unassigned_bidders.insert(bidder_idx); + } + +#ifdef ORDERED_BY_PERSISTENCE + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if (is_bidder_normal(bidder_idx)) { + unassigned_normal_bidders_by_persistence.insert( + std::make_pair(bidders[bidder_idx].persistence_lp(1.0), bidder_idx)); + } + } +#endif + oracle.adjust_prices(); + + partial_cost = 0.0; + + +#ifndef WASSERSTEIN_PURE_GEOM + for (size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if (is_bidder_normal(bidder_idx)) { + unassigned_normal_bidders.insert(bidder_idx); + } else { + unassigned_diag_bidders.insert(bidder_idx); + } + } + + unassigned_bidders_persistence = total_bidders_persistence; + unassigned_items_persistence = total_items_persistence; + +#ifdef LOG_AUCTION + + price_change_cnt_vec.push_back(std::vector(num_items, 0)); + + never_assigned_bidders = unassigned_bidders; + + for (size_t item_idx = 0; item_idx < items.size(); ++item_idx) { + unassigned_items.insert(item_idx); + if (is_item_normal(item_idx)) { + unassigned_normal_items.insert(item_idx); + } else { + unassigned_diag_items.insert(item_idx); + } + } + + num_diag_bids_submitted = 0; + num_normal_bids_submitted = 0; + num_diag_assignments = 0; + num_normal_assignments = 0; + + all_assigned_round = 0; + all_assigned_round_found = false; + num_rounds_non_cumulative = 0; +#endif +#endif + + } // flush_assignment + + + template + void AuctionRunnerJac::set_epsilon(Real new_val) { + assert(new_val > 0.0); + epsilon = new_val; + oracle.set_epsilon(new_val); + } + + template + void AuctionRunnerJac::run_auction_phases(const int max_num_phases, const Real _initial_epsilon) { + set_epsilon(_initial_epsilon); + assert(oracle.get_epsilon() > 0); + for (int phase_num = 0; phase_num < max_num_phases; ++phase_num) { + flush_assignment(); + run_auction_phase(); + +#ifdef LOG_AUCTION + console_logger->info( + "Phase {0} done, current_result = {1}, eps = {2}, error = {7}, num_rounds = {3}, num_assignments = {4}, num_bids_submitted = {5}, # unassigned = {6}", + num_phase, + partial_cost, + get_epsilon(), + format_int<>(num_rounds), + format_int<>(num_normal_assignments + num_diag_assignments), + format_int<>(num_normal_bids_submitted + num_diag_bids_submitted), + unassigned_bidders.size(), + get_relative_error(num_phase == 1) + ); + +// console_logger->info("num_rounds (non-cumulative)= {0}, num_diag_assignments = {1}, num_normal_assignments = {2}, num_diag_bids_submitted = {3}, num_normal_bids_submitted = {4}", +// format_int<>(num_rounds_non_cumulative), +// format_int<>(num_diag_assignments), +// format_int<>(num_normal_assignments), +// format_int<>(num_diag_bids_submitted), +// format_int<>(num_normal_bids_submitted) +// ); + + console_logger->info( + "num_parallel_bids / num_total_bids = {0} / {1} = {2}, num_parallel_assignments / num_total_aassignments = {3} / {4} = {5}", + format_int<>(num_parallel_bids), + format_int<>(num_total_bids), + static_cast(num_parallel_bids) / static_cast(num_total_bids), + format_int<>(num_parallel_assignments), + format_int<>(num_total_assignments), + static_cast(num_parallel_assignments) / static_cast(num_total_assignments) + ); + + console_logger->info( + "num_parallel_diag_bids / num_total_diag_bids = {0} / {1} = {2}, num_parallel_normal_bids / num_total_normal_bids = {3} / {4} = {5}", + format_int<>(num_parallel_diag_bids), + format_int<>(num_total_diag_bids), + static_cast(num_parallel_diag_bids) / static_cast(num_total_diag_bids), + format_int<>(num_parallel_normal_bids), + format_int<>(num_total_normal_bids), + static_cast(num_parallel_normal_bids) / static_cast(num_total_normal_bids) + ); + +// console_logger->info("num_rounds before all biders assigned = {0}, num_rounds (non-cumulative)= {1}, fraction = {2}", +// format_int<>(all_assigned_round), +// format_int<>(num_rounds_non_cumulative), +// static_cast(all_assigned_round) / static_cast(num_rounds_non_cumulative) +// ); + + for (size_t item_idx = 0; item_idx < num_items; ++item_idx) { + price_stat_logger->info("{0} {1} {2} {3} {4}", + phase_num, + item_idx, + items[item_idx][0], + items[item_idx][1], + price_change_cnt_vec.back()[item_idx] + ); + } +#endif + + + if (is_done()) + break; + else + decrease_epsilon(); + + } + } + + template + void AuctionRunnerJac::decrease_epsilon() { + set_epsilon(get_epsilon() / epsilon_common_ratio); + } + + template + void AuctionRunnerJac::run_auction() + { + if (num_bidders == 1) { + assign_item_to_bidder(0, 0); + wasserstein_cost = get_item_bidder_cost(0,0); + return; + } + double init_eps = (initial_epsilon > 0.0) ? initial_epsilon : oracle.max_val_ / 4.0; + run_auction_phases(max_num_phases, init_eps); + is_distance_computed = true; + wasserstein_cost = partial_cost; + if (not is_done()) { +#ifndef FOR_R_TDA + std::cerr << "Maximum iteration number exceeded, exiting. Current result is: "; + std::cerr << get_wasserstein_distance() << std::endl; +#endif + throw std::runtime_error("Maximum iteration number exceeded"); + } + } + + template + void AuctionRunnerJac::add_unassigned_bidder(const size_t bidder_idx) + { + unassigned_bidders.insert(bidder_idx); + +#ifndef WASSERSTEIN_PURE_GEOM + const auto &bidder = bidders[bidder_idx]; + unassigned_bidders_persistence += get_cost_to_diagonal(bidder); + + if (is_bidder_diagonal(bidder_idx)) { + unassigned_diag_bidders.insert(bidder_idx); + } else { + unassigned_normal_bidders.insert(bidder_idx); + } +#ifdef ORDERED_BY_PERSISTENCE + if (is_bidder_normal(bidder_idx)) { + unassigned_normal_bidders_by_persistence.insert(std::make_pair(bidder.persistence_lp(1.0), bidder_idx)); + } +#endif + +#endif + } + + template + void AuctionRunnerJac::remove_unassigned_bidder(const size_t bidder_idx) + { + unassigned_bidders.erase(bidder_idx); +#ifndef WASSERSTEIN_PURE_GEOM + const auto &bidder = bidders[bidder_idx]; + unassigned_bidders_persistence -= get_cost_to_diagonal(bidder); + +#ifdef ORDERED_BY_PERSISTENCE + if (is_bidder_normal(bidder_idx)) { + unassigned_normal_bidders_by_persistence.erase(std::make_pair(bidder.persistence_lp(1.0), bidder_idx)); + } +#endif + + if (is_bidder_diagonal(bidder_idx)) { + unassigned_diag_bidders.erase(bidder_idx); + } else { + unassigned_normal_bidders.erase(bidder_idx); + } + + +#ifdef LOG_AUCTION + never_assigned_bidders.erase(bidder_idx); + if (never_assigned_bidders.empty() and not all_assigned_round_found) { + all_assigned_round = num_rounds_non_cumulative; + all_assigned_round_found = true; + } +#endif +#endif + } + + template + void AuctionRunnerJac::remove_unassigned_item(const size_t item_idx) { +#ifndef WASSERSTEIN_PURE_GEOM + unassigned_items_persistence -= get_cost_to_diagonal(items[item_idx]); + +#ifdef LOG_AUCTION + unassigned_items.erase(item_idx); + + if (is_item_normal(item_idx)) { + unassigned_normal_items.erase(item_idx); + } else { + unassigned_diag_items.erase(item_idx); + } +#endif +#endif + } + + template + template + void AuctionRunnerJac::run_bidding_step(const Range &active_bidders) + { +#ifdef LOG_AUCTION + is_step_parallel = false; + size_t diag_bids_submitted = 0; + size_t normal_bids_submitted = 0; +#endif + + clear_bid_table(); + size_t bids_submitted = 0; + for (const auto bidder_idx : active_bidders) { + + ++bids_submitted; + + submit_bid(bidder_idx, oracle.get_optimal_bid(bidder_idx)); + +#ifdef LOG_AUCTION + if (is_bidder_diagonal(bidder_idx)) { + diag_bids_submitted++; + } else { + normal_bids_submitted++; + } + + if (bids_submitted >= parallel_threshold) { + is_step_parallel = true; + } + + if (bids_submitted >= max_bids_per_round) { + break; + } + if (diag_first and not unassigned_diag_bidders.empty() and + diag_bids_submitted >= oracle.get_heap_top_size()) { + continue; + } +#endif + } + +#ifdef LOG_AUCTION + num_total_diag_bids += diag_bids_submitted; + num_total_normal_bids += normal_bids_submitted; + if (is_step_parallel) { + num_parallel_bids += bids_submitted; + num_parallel_diag_bids += diag_bids_submitted; + num_parallel_normal_bids += normal_bids_submitted; + } +#endif + } + + template + bool AuctionRunnerJac::is_done() const + { + return get_relative_error() <= delta; + } + + template + bool AuctionRunnerJac::continue_auction_phase() const + { + return not unassigned_bidders.empty() and not is_done(); + } + + template + void AuctionRunnerJac::run_auction_phase() + { + num_phase++; + //console_logger->debug("Entered run_auction_phase"); + + do { + num_rounds++; +#ifdef LOG_AUCTION + num_diag_stole_from_diag = 0; + num_normal_assignments_non_cumulative = 0; + num_diag_assignments_non_cumulative = 0; + num_rounds_non_cumulative++; +#endif + + // bidding +#ifdef ORDERED_BY_PERSISTENCE + if (not unassigned_diag_bidders.empty()) { + run_bidding_step(unassigned_diag_bidders); + } else { + std::vector active_bidders; + active_bidders.reserve(batch_size); + for (auto iter = unassigned_normal_bidders_by_persistence.begin(); iter != unassigned_normal_bidders_by_persistence.end(); ++iter) { + active_bidders.push_back(iter->second); + if (active_bidders.size() >= batch_size) { + break; + } + } + run_bidding_step(active_bidders); + } +#elif defined WASSERSTEIN_PURE_GEOM + run_bidding_step(unassigned_bidders); +#else + if (diag_first and not unassigned_diag_bidders.empty()) { + run_bidding_step(unassigned_diag_bidders); + } else { + run_bidding_step(unassigned_bidders); + } +#endif + + // assignment + for (auto item_idx : items_with_bids) { + assign_to_best_bidder(item_idx); + } +#ifdef LOG_AUCTION + plot_logger->info("{0} {1} {2} {3} {4} {5} {6} {7} {8} {9} {10} {11} {12} {13} {14}", + num_phase, + num_rounds, + unassigned_bidders.size(), + get_gamma(), + partial_cost, + oracle.get_epsilon(), + unassigned_normal_bidders.size(), + unassigned_diag_bidders.size(), + unassigned_normal_items.size(), + unassigned_diag_items.size(), + num_normal_assignments_non_cumulative, + num_diag_assignments_non_cumulative, + oracle.get_heap_top_size(), + get_relative_error(false), + num_diag_stole_from_diag + ); +#endif + //sanity_check(); + } while (continue_auction_phase()); + } + + template + typename AuctionRunnerJac::Real + AuctionRunnerJac::get_wasserstein_distance() + { + assert(is_distance_computed); + return std::pow(wasserstein_cost, 1.0 / wasserstein_power); + } + + template + typename AuctionRunnerJac::Real + AuctionRunnerJac::get_wasserstein_cost() + { + assert(is_distance_computed); + return wasserstein_cost; + } + + + template + void AuctionRunnerJac::sanity_check() + { +#ifdef DEBUG_AUCTION + if (bidders_to_items.size() != num_bidders) { + std::cerr << "Wrong size of bidders_to_items, must be " << num_bidders << ", is " << bidders_to_items.size() << std::endl; + throw "Wrong size of bidders_to_items"; + } + + if (items_to_bidders.size() != num_bidders) { + std::cerr << "Wrong size of items_to_bidders, must be " << num_bidders << ", is " << items_to_bidders.size() << std::endl; + throw "Wrong size of items_to_bidders"; + } + + for(size_t bidder_idx = 0; bidder_idx < num_bidders; ++bidder_idx) { + if ( bidders_to_items[bidder_idx] >= 0) { + + if ( std::count(bidders_to_items.begin(), + bidders_to_items.end(), + bidders_to_items[bidder_idx]) > 1 ) { + std::cerr << "Good " << bidders_to_items[bidder_idx]; + std::cerr << " appears in bidders_to_items more than once" << std::endl; + throw "Duplicate in bidders_to_items"; + } + + if (items_to_bidders.at(bidders_to_items[bidder_idx]) != static_cast(bidder_idx)) { + std::cerr << "Inconsitency: bidder_idx = " << bidder_idx; + std::cerr << ", item_idx in bidders_to_items = "; + std::cerr << bidders_to_items[bidder_idx]; + std::cerr << ", bidder_idx in items_to_bidders = "; + std::cerr << items_to_bidders[bidders_to_items[bidder_idx]] << std::endl; + throw "inconsistent mapping"; + } + } + } + + for(IdxType item_idx = 0; item_idx < static_cast(num_bidders); ++item_idx) { + if ( items_to_bidders[item_idx] >= 0) { + + // check for uniqueness + if ( std::count(items_to_bidders.begin(), + items_to_bidders.end(), + items_to_bidders[item_idx]) > 1 ) { + std::cerr << "Bidder " << items_to_bidders[item_idx]; + std::cerr << " appears in items_to_bidders more than once" << std::endl; + throw "Duplicate in items_to_bidders"; + } + // check for consistency + if (bidders_to_items.at(items_to_bidders[item_idx]) != static_cast(item_idx)) { + std::cerr << "Inconsitency: item_idx = " << item_idx; + std::cerr << ", bidder_idx in items_to_bidders = "; + std::cerr << items_to_bidders[item_idx]; + std::cerr << ", item_idx in bidders_to_items= "; + std::cerr << bidders_to_items[items_to_bidders[item_idx]] << std::endl; + throw "inconsistent mapping"; + } + } + } +#endif + } + + template + void AuctionRunnerJac::print_matching() { +#ifdef DEBUG_AUCTION + sanity_check(); + for(size_t bidder_idx = 0; bidder_idx < bidders_to_items.size(); ++bidder_idx) { + if (bidders_to_items[bidder_idx] >= 0) { + auto pA = bidders[bidder_idx]; + auto pB = items[bidders_to_items[bidder_idx]]; + std::cout << pA << " <-> " << pB << "+" << pow(dist_lp(pA, pB, internal_p, dimension), wasserstein_power) << std::endl; + } else { + assert(false); + } + } +#endif + } + +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/basic_defs_ws.h b/src/dionysus/wasserstein/basic_defs_ws.h new file mode 100755 index 0000000..58d6fd2 --- /dev/null +++ b/src/dionysus/wasserstein/basic_defs_ws.h @@ -0,0 +1,336 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef BASIC_DEFS_WS_H +#define BASIC_DEFS_WS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#endif + +#ifndef FOR_R_TDA +#include "spdlog/spdlog.h" +#include "spdlog/fmt/fmt.h" +#include "spdlog/fmt/ostr.h" +#endif + +#include "dnn/geometry/euclidean-dynamic.h" +#include "def_debug_ws.h" + +#define MIN_VALID_ID 10 + +namespace hera +{ + +template +bool is_infinity(const Real& x) +{ + return x == Real(-1); +}; + +template +Real get_infinity() +{ + return Real( -1 ); +} + +template +bool is_p_valid_norm(const Real& p) +{ + return is_infinity(p) or p >= Real(1); +} + +template +struct AuctionParams +{ + Real wasserstein_power { 1.0 }; + Real delta { 0.01 }; // relative error + Real internal_p { get_infinity() }; + Real initial_epsilon { 0.0 }; // 0.0 means maxVal / 4.0 + Real epsilon_common_ratio { 5.0 }; + Real gamma_threshold { 0.0 }; // for experiments, not in use now + int max_num_phases { std::numeric_limits::max() }; + size_t max_bids_per_round { 1 }; // imitate Gauss-Seidel is default behaviour + unsigned int dim { 2 }; // for pure geometric version only; ignored in persistence diagrams +}; + +namespace ws +{ + + using IdxType = int; + + constexpr size_t k_invalid_index = std::numeric_limits::max(); + + template + using IdxValPair = std::pair; + + + + template + std::ostream& operator<<(std::ostream& output, const IdxValPair p) + { + output << fmt::format("({0}, {1})", p.first, p.second); + return output; + } + + enum class OwnerType { k_none, k_normal, k_diagonal }; + + std::ostream& operator<<(std::ostream& s, const OwnerType t) + { + switch(t) + { + case OwnerType::k_none : s << "NONE"; break; + case OwnerType::k_normal: s << "NORMAL"; break; + case OwnerType::k_diagonal: s << "DIAGONAL"; break; + } + return s; + } + + template + struct Point { + Real x, y; + bool operator==(const Point& other) const; + bool operator!=(const Point& other) const; + Point(Real _x, Real _y) : x(_x), y(_y) {} + Point() : x(0.0), y(0.0) {} + }; + +#ifndef FOR_R_TDA + template + std::ostream& operator<<(std::ostream& output, const Point p); +#endif + + template + inline void hash_combine(std::size_t & seed, const T & v) + { + std::hash hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + template + struct DiagramPoint + { + using Real = Real_; + // data members + // Points above the diagonal have type NORMAL + // Projections onto the diagonal have type DIAG + // for DIAG points only x-coordinate is relevant + enum Type { NORMAL, DIAG}; + Real x, y; + Type type; + // methods + DiagramPoint(Real xx, Real yy, Type ttype); + bool is_diagonal() const { return type == DIAG; } + bool is_normal() const { return type == NORMAL; } + Real getRealX() const; // return the x-coord + Real getRealY() const; // return the y-coord + Real persistence_lp(const Real p) const; + struct LexicographicCmp + { + bool operator()(const DiagramPoint& p1, const DiagramPoint& p2) const + { return p1.type < p2.type || (p1.type == p2.type && (p1.x < p2.x || (p1.x == p2.x && p1.y < p2.y))); } + }; + + const Real& operator[](const int idx) const + { + switch(idx) + { + case 0 : return x; + break; + case 1 : return y; + break; + default: throw std::out_of_range("DiagramPoint has dimension 2"); + } + } + + Real& operator[](const int idx) + { + switch(idx) + { + case 0 : return x; + break; + case 1 : return y; + break; + default: throw std::out_of_range("DiagramPoint has dimension 2"); + } + } + + }; + + + template + struct DiagramPointHash { + size_t operator()(const DiagramPoint &p) const + { + std::size_t seed = 0; + hash_combine(seed, std::hash(p.x)); + hash_combine(seed, std::hash(p.y)); + hash_combine(seed, std::hash(p.is_diagonal())); + return seed; + } + }; + + +#ifndef FOR_R_TDA + template + std::ostream& operator<<(std::ostream& output, const DiagramPoint p); +#endif + + template + void format_arg(fmt::BasicFormatter &f, const char *&format_str, const DiagramPoint&p) { + if (p.is_diagonal()) { + f.writer().write("({0},{1}, DIAG)", p.x, p.y); + } else { + f.writer().write("({0},{1}, NORM)", p.x, p.y); + } + } + + + template + struct DistImpl + { + Real operator()(const Pt& a, const Pt& b, const Real p, const int dim) + { + Real result = 0.0; + if (hera::is_infinity(p)) { + for(int d = 0; d < dim; ++d) { + result = std::max(result, std::fabs(a[d] - b[d])); + } + } else if (p == 1.0) { + for(int d = 0; d < dim; ++d) { + result += std::fabs(a[d] - b[d]); + } + } else { + assert(p > 1.0); + for(int d = 0; d < dim; ++d) { + result += std::pow(std::fabs(a[d] - b[d]), p); + } + result = std::pow(result, 1.0 / p); + } + return result; + } + }; + + template + struct DistImpl> + { + Real operator()(const DiagramPoint& a, const DiagramPoint& b, const Real p, const int dim) + { + Real result = 0.0; + if ( a.is_diagonal() and b.is_diagonal()) { + return result; + } else if (hera::is_infinity(p)) { + result = std::max(std::fabs(a.getRealX() - b.getRealX()), std::fabs(a.getRealY() - b.getRealY())); + } else if (p == 1.0) { + result = std::fabs(a.getRealX() - b.getRealX()) + std::fabs(a.getRealY() - b.getRealY()); + } else { + assert(p > 1.0); + result = std::pow(std::pow(std::fabs(a.getRealX() - b.getRealX()), p) + std::pow(std::fabs(a.getRealY() - b.getRealY()), p), 1.0 / p); + } + return result; + } + }; + + template + R dist_lp(const Pt& a, const Pt& b, const R p, const int dim) + { + return DistImpl()(a, b, p, dim); + } + + // TODO + template + double getFurthestDistance3Approx(DiagPointContainer& A, DiagPointContainer& B, const Real p) + { + int dim = 2; + Real result { 0.0 }; + DiagramPoint begA = *(A.begin()); + DiagramPoint optB = *(B.begin()); + for(const auto& pointB : B) { + if (dist_lp(begA, pointB, p, dim) > result) { + result = dist_lp(begA, pointB, p, dim); + optB = pointB; + } + } + for(const auto& pointA : A) { + if (dist_lp(pointA, optB, p, dim) > result) { + result = dist_lp(pointA, optB, p, dim); + } + } + return result; + } + + template + Real getFurthestDistance3Approx_pg(const hera::ws::dnn::DynamicPointVector& A, const hera::ws::dnn::DynamicPointVector& B, const Real p, const int dim) + { + Real result { 0.0 }; + int opt_b_idx = 0; + for(size_t b_idx = 0; b_idx < B.size(); ++b_idx) { + if (dist_lp(A[0], B[b_idx], p, dim) > result) { + result = dist_lp(A[0], B[b_idx], p, dim); + opt_b_idx = b_idx; + } + } + + for(size_t a_idx = 0; a_idx < A.size(); ++a_idx) { + result = std::max(result, dist_lp(A[a_idx], B[opt_b_idx], p, dim)); + } + + return result; + } + + + template + std::string format_container_to_log(const Container& cont); + + template + std::string format_point_set_to_log(const IndexContainer& indices, const std::vector>& points); + + template + std::string format_int(T i); + +} // ws +} // hera + + + +#include "basic_defs_ws.hpp" + + +#endif diff --git a/src/dionysus/wasserstein/basic_defs_ws.hpp b/src/dionysus/wasserstein/basic_defs_ws.hpp new file mode 100755 index 0000000..629a2f8 --- /dev/null +++ b/src/dionysus/wasserstein/basic_defs_ws.hpp @@ -0,0 +1,193 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#include +#include +#include +#include "basic_defs_ws.h" + +#ifndef FOR_R_TDA +#include +#endif + +#include + +namespace hera { +namespace ws { +// Point + +template +bool Point::operator==(const Point& other) const +{ + return ((this->x == other.x) and (this->y == other.y)); +} + +template +bool Point::operator!=(const Point& other) const +{ + return !(*this == other); +} + + +#ifndef FOR_R_TDA +template +std::ostream& operator<<(std::ostream& output, const Point p) +{ + output << "(" << p.x << ", " << p.y << ")"; + return output; +} +#endif + +template +Real sqr_dist(const Point& a, const Point& b) +{ + return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y); +} + +template +Real dist(const Point& a, const Point& b) +{ + return sqrt(sqr_dist(a, b)); +} + + +template +Real DiagramPoint::persistence_lp(const Real p) const +{ + if (is_diagonal()) + return 0.0; + else { + Real u { (getRealY() + getRealX())/2 }; + int dim = 2; + DiagramPoint a_proj(u, u, DiagramPoint::DIAG); + return dist_lp(*this, a_proj, p, dim); + } +} + + +#ifndef FOR_R_TDA +template +std::ostream& operator<<(std::ostream& output, const DiagramPoint p) +{ + if ( p.type == DiagramPoint::DIAG ) { + output << "(" << p.x << ", " << p.y << ", " << 0.5 * (p.x + p.y) << " DIAG )"; + } else { + output << "(" << p.x << ", " << p.y << ", " << " NORMAL)"; + } + return output; +} +#endif + +template +DiagramPoint::DiagramPoint(Real xx, Real yy, Type ttype) : + x(xx), + y(yy), + type(ttype) +{ + //if ( yy < xx ) + //throw "Point is below the diagonal"; + //if ( yy == xx and ttype != DiagramPoint::DIAG) + //throw "Point on the main diagonal must have DIAG type"; +} + +template +Real DiagramPoint::getRealX() const +{ + if (is_normal()) + return x; + else + return Real(0.5) * (x + y); +} + +template +Real DiagramPoint::getRealY() const +{ + if (is_normal()) + return y; + else + return Real(0.5) * (x + y); +} + +template +std::string format_container_to_log(const Container& cont) +{ + std::stringstream result; + result << "["; + for(auto iter = cont.begin(); iter != cont.end(); ++iter) { + result << *iter; + if (std::next(iter) != cont.end()) { + result << ", "; + } + } + result << "]"; + return result.str(); +} + +template +std::string format_pair_container_to_log(const Container& cont) +{ + std::stringstream result; + result << "["; + for(auto iter = cont.begin(); iter != cont.end(); ++iter) { + result << "(" << iter->first << ", " << iter->second << ")"; + if (std::next(iter) != cont.end()) { + result << ", "; + } + } + result << "]"; + return result.str(); +} + + +template +std::string format_point_set_to_log(const IndexContainer& indices, + const std::vector>& points) +{ + std::stringstream result; + result << "["; + for(auto iter = indices.begin(); iter != indices.end(); ++iter) { + DiagramPoint p = points[*iter]; + result << "(" << p.getRealX() << ", " << p.getRealY() << ")"; + if (std::next(iter) != indices.end()) + result << ", "; + } + result << "]"; + return result.str(); +} + +template +std::string format_int(T i) +{ + std::stringstream ss; + ss.imbue(std::locale("")); + ss << std::fixed << i; + return ss.str(); +} + + +} // end of namespace ws +} // hera diff --git a/src/dionysus/wasserstein/catch/catch.hpp b/src/dionysus/wasserstein/catch/catch.hpp new file mode 100755 index 0000000..f7681f4 --- /dev/null +++ b/src/dionysus/wasserstein/catch/catch.hpp @@ -0,0 +1,11545 @@ +/* + * Catch v1.9.6 + * Generated: 2017-06-27 12:19:54.557875 + * ---------------------------------------------------------- + * This file has been merged from multiple headers. Please don't edit it directly + * Copyright (c) 2012 Two Blue Cubes Ltd. All rights reserved. + * + * Distributed under the Boost Software License, Version 1.0. (See accompanying + * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + */ +#ifndef TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED +#define TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED + +#define TWOBLUECUBES_CATCH_HPP_INCLUDED + +#ifdef __clang__ +# pragma clang system_header +#elif defined __GNUC__ +# pragma GCC system_header +#endif + +// #included from: internal/catch_suppress_warnings.h + +#ifdef __clang__ +# ifdef __ICC // icpc defines the __clang__ macro +# pragma warning(push) +# pragma warning(disable: 161 1682) +# else // __ICC +# pragma clang diagnostic ignored "-Wglobal-constructors" +# pragma clang diagnostic ignored "-Wvariadic-macros" +# pragma clang diagnostic ignored "-Wc99-extensions" +# pragma clang diagnostic ignored "-Wunused-variable" +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpadded" +# pragma clang diagnostic ignored "-Wc++98-compat" +# pragma clang diagnostic ignored "-Wc++98-compat-pedantic" +# pragma clang diagnostic ignored "-Wswitch-enum" +# pragma clang diagnostic ignored "-Wcovered-switch-default" +# endif +#elif defined __GNUC__ +# pragma GCC diagnostic ignored "-Wvariadic-macros" +# pragma GCC diagnostic ignored "-Wunused-variable" +# pragma GCC diagnostic ignored "-Wparentheses" + +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wpadded" +#endif +#if defined(CATCH_CONFIG_MAIN) || defined(CATCH_CONFIG_RUNNER) +# define CATCH_IMPL +#endif + +#ifdef CATCH_IMPL +# ifndef CLARA_CONFIG_MAIN +# define CLARA_CONFIG_MAIN_NOT_DEFINED +# define CLARA_CONFIG_MAIN +# endif +#endif + +// #included from: internal/catch_notimplemented_exception.h +#define TWOBLUECUBES_CATCH_NOTIMPLEMENTED_EXCEPTION_H_INCLUDED + +// #included from: catch_common.h +#define TWOBLUECUBES_CATCH_COMMON_H_INCLUDED + +// #included from: catch_compiler_capabilities.h +#define TWOBLUECUBES_CATCH_COMPILER_CAPABILITIES_HPP_INCLUDED + +// Detect a number of compiler features - mostly C++11/14 conformance - by compiler +// The following features are defined: +// +// CATCH_CONFIG_CPP11_NULLPTR : is nullptr supported? +// CATCH_CONFIG_CPP11_NOEXCEPT : is noexcept supported? +// CATCH_CONFIG_CPP11_GENERATED_METHODS : The delete and default keywords for compiler generated methods +// CATCH_CONFIG_CPP11_IS_ENUM : std::is_enum is supported? +// CATCH_CONFIG_CPP11_TUPLE : std::tuple is supported +// CATCH_CONFIG_CPP11_LONG_LONG : is long long supported? +// CATCH_CONFIG_CPP11_OVERRIDE : is override supported? +// CATCH_CONFIG_CPP11_UNIQUE_PTR : is unique_ptr supported (otherwise use auto_ptr) +// CATCH_CONFIG_CPP11_SHUFFLE : is std::shuffle supported? +// CATCH_CONFIG_CPP11_TYPE_TRAITS : are type_traits and enable_if supported? + +// CATCH_CONFIG_CPP11_OR_GREATER : Is C++11 supported? + +// CATCH_CONFIG_VARIADIC_MACROS : are variadic macros supported? +// CATCH_CONFIG_COUNTER : is the __COUNTER__ macro supported? +// CATCH_CONFIG_WINDOWS_SEH : is Windows SEH supported? +// CATCH_CONFIG_POSIX_SIGNALS : are POSIX signals supported? +// **************** +// Note to maintainers: if new toggles are added please document them +// in configuration.md, too +// **************** + +// In general each macro has a _NO_ form +// (e.g. CATCH_CONFIG_CPP11_NO_NULLPTR) which disables the feature. +// Many features, at point of detection, define an _INTERNAL_ macro, so they +// can be combined, en-mass, with the _NO_ forms later. + +// All the C++11 features can be disabled with CATCH_CONFIG_NO_CPP11 + +#ifdef __cplusplus + +# if __cplusplus >= 201103L +# define CATCH_CPP11_OR_GREATER +# endif + +# if __cplusplus >= 201402L +# define CATCH_CPP14_OR_GREATER +# endif + +#endif + +#ifdef __clang__ + +# if __has_feature(cxx_nullptr) +# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR +# endif + +# if __has_feature(cxx_noexcept) +# define CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT +# endif + +# if defined(CATCH_CPP11_OR_GREATER) +# define CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wexit-time-destructors\"" ) +# define CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wparentheses\"" ) +# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + _Pragma( "clang diagnostic pop" ) +# endif + +#endif // __clang__ + +//////////////////////////////////////////////////////////////////////////////// +// We know some environments not to support full POSIX signals +#if defined(__CYGWIN__) || defined(__QNX__) + +# if !defined(CATCH_CONFIG_POSIX_SIGNALS) +# define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS +# endif + +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Cygwin +#ifdef __CYGWIN__ + +// Required for some versions of Cygwin to declare gettimeofday +// see: http://stackoverflow.com/questions/36901803/gettimeofday-not-declared-in-this-scope-cygwin +# define _BSD_SOURCE + +#endif // __CYGWIN__ + +//////////////////////////////////////////////////////////////////////////////// +// Borland +#ifdef __BORLANDC__ + +#endif // __BORLANDC__ + +//////////////////////////////////////////////////////////////////////////////// +// EDG +#ifdef __EDG_VERSION__ + +#endif // __EDG_VERSION__ + +//////////////////////////////////////////////////////////////////////////////// +// Digital Mars +#ifdef __DMC__ + +#endif // __DMC__ + +//////////////////////////////////////////////////////////////////////////////// +// GCC +#ifdef __GNUC__ + +# if __GNUC__ == 4 && __GNUC_MINOR__ >= 6 && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR +# endif + +// - otherwise more recent versions define __cplusplus >= 201103L +// and will get picked up below + +#endif // __GNUC__ + +//////////////////////////////////////////////////////////////////////////////// +// Visual C++ +#ifdef _MSC_VER + +#define CATCH_INTERNAL_CONFIG_WINDOWS_SEH + +#if (_MSC_VER >= 1600) +# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR +# define CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR +#endif + +#if (_MSC_VER >= 1900 ) // (VC++ 13 (VS2015)) +#define CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT +#define CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS +#define CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE +#define CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS +#endif + +#endif // _MSC_VER + +//////////////////////////////////////////////////////////////////////////////// + +// Use variadic macros if the compiler supports them +#if ( defined _MSC_VER && _MSC_VER > 1400 && !defined __EDGE__) || \ + ( defined __WAVE__ && __WAVE_HAS_VARIADICS ) || \ + ( defined __GNUC__ && __GNUC__ >= 3 ) || \ + ( !defined __cplusplus && __STDC_VERSION__ >= 199901L || __cplusplus >= 201103L ) + +#define CATCH_INTERNAL_CONFIG_VARIADIC_MACROS + +#endif + +// Use __COUNTER__ if the compiler supports it +#if ( defined _MSC_VER && _MSC_VER >= 1300 ) || \ + ( defined __GNUC__ && ( __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3 )) ) || \ + ( defined __clang__ && __clang_major__ >= 3 ) + +#define CATCH_INTERNAL_CONFIG_COUNTER + +#endif + +//////////////////////////////////////////////////////////////////////////////// +// C++ language feature support + +// catch all support for C++11 +#if defined(CATCH_CPP11_OR_GREATER) + +# if !defined(CATCH_INTERNAL_CONFIG_CPP11_NULLPTR) +# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR +# endif + +# ifndef CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT +# define CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT +# endif + +# ifndef CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS +# define CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS +# endif + +# ifndef CATCH_INTERNAL_CONFIG_CPP11_IS_ENUM +# define CATCH_INTERNAL_CONFIG_CPP11_IS_ENUM +# endif + +# ifndef CATCH_INTERNAL_CONFIG_CPP11_TUPLE +# define CATCH_INTERNAL_CONFIG_CPP11_TUPLE +# endif + +# ifndef CATCH_INTERNAL_CONFIG_VARIADIC_MACROS +# define CATCH_INTERNAL_CONFIG_VARIADIC_MACROS +# endif + +# if !defined(CATCH_INTERNAL_CONFIG_CPP11_LONG_LONG) +# define CATCH_INTERNAL_CONFIG_CPP11_LONG_LONG +# endif + +# if !defined(CATCH_INTERNAL_CONFIG_CPP11_OVERRIDE) +# define CATCH_INTERNAL_CONFIG_CPP11_OVERRIDE +# endif +# if !defined(CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR) +# define CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR +# endif +# if !defined(CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE) +# define CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE +# endif +# if !defined(CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS) +# define CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS +# endif + +#endif // __cplusplus >= 201103L + +// Now set the actual defines based on the above + anything the user has configured +#if defined(CATCH_INTERNAL_CONFIG_CPP11_NULLPTR) && !defined(CATCH_CONFIG_CPP11_NO_NULLPTR) && !defined(CATCH_CONFIG_CPP11_NULLPTR) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_NULLPTR +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_CONFIG_CPP11_NO_NOEXCEPT) && !defined(CATCH_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_NOEXCEPT +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS) && !defined(CATCH_CONFIG_CPP11_NO_GENERATED_METHODS) && !defined(CATCH_CONFIG_CPP11_GENERATED_METHODS) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_GENERATED_METHODS +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_IS_ENUM) && !defined(CATCH_CONFIG_CPP11_NO_IS_ENUM) && !defined(CATCH_CONFIG_CPP11_IS_ENUM) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_IS_ENUM +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_TUPLE) && !defined(CATCH_CONFIG_CPP11_NO_TUPLE) && !defined(CATCH_CONFIG_CPP11_TUPLE) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_TUPLE +#endif +#if defined(CATCH_INTERNAL_CONFIG_VARIADIC_MACROS) && !defined(CATCH_CONFIG_NO_VARIADIC_MACROS) && !defined(CATCH_CONFIG_VARIADIC_MACROS) +# define CATCH_CONFIG_VARIADIC_MACROS +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_LONG_LONG) && !defined(CATCH_CONFIG_CPP11_NO_LONG_LONG) && !defined(CATCH_CONFIG_CPP11_LONG_LONG) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_LONG_LONG +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_OVERRIDE) && !defined(CATCH_CONFIG_CPP11_NO_OVERRIDE) && !defined(CATCH_CONFIG_CPP11_OVERRIDE) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_OVERRIDE +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR) && !defined(CATCH_CONFIG_CPP11_NO_UNIQUE_PTR) && !defined(CATCH_CONFIG_CPP11_UNIQUE_PTR) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_UNIQUE_PTR +#endif +// Use of __COUNTER__ is suppressed if __JETBRAINS_IDE__ is #defined (meaning we're being parsed by a JetBrains IDE for +// analytics) because, at time of writing, __COUNTER__ is not properly handled by it. +// This does not affect compilation +#if defined(CATCH_INTERNAL_CONFIG_COUNTER) && !defined(CATCH_CONFIG_NO_COUNTER) && !defined(CATCH_CONFIG_COUNTER) && !defined(__JETBRAINS_IDE__) +# define CATCH_CONFIG_COUNTER +#endif +#if defined(CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE) && !defined(CATCH_CONFIG_CPP11_NO_SHUFFLE) && !defined(CATCH_CONFIG_CPP11_SHUFFLE) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_SHUFFLE +#endif +# if defined(CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS) && !defined(CATCH_CONFIG_CPP11_NO_TYPE_TRAITS) && !defined(CATCH_CONFIG_CPP11_TYPE_TRAITS) && !defined(CATCH_CONFIG_NO_CPP11) +# define CATCH_CONFIG_CPP11_TYPE_TRAITS +# endif +#if defined(CATCH_INTERNAL_CONFIG_WINDOWS_SEH) && !defined(CATCH_CONFIG_NO_WINDOWS_SEH) && !defined(CATCH_CONFIG_WINDOWS_SEH) +# define CATCH_CONFIG_WINDOWS_SEH +#endif +// This is set by default, because we assume that unix compilers are posix-signal-compatible by default. +#if !defined(CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_POSIX_SIGNALS) +# define CATCH_CONFIG_POSIX_SIGNALS +#endif + +#if !defined(CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS +#endif +#if !defined(CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS +#endif + +// noexcept support: +#if defined(CATCH_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_NOEXCEPT) +# define CATCH_NOEXCEPT noexcept +# define CATCH_NOEXCEPT_IS(x) noexcept(x) +#else +# define CATCH_NOEXCEPT throw() +# define CATCH_NOEXCEPT_IS(x) +#endif + +// nullptr support +#ifdef CATCH_CONFIG_CPP11_NULLPTR +# define CATCH_NULL nullptr +#else +# define CATCH_NULL NULL +#endif + +// override support +#ifdef CATCH_CONFIG_CPP11_OVERRIDE +# define CATCH_OVERRIDE override +#else +# define CATCH_OVERRIDE +#endif + +// unique_ptr support +#ifdef CATCH_CONFIG_CPP11_UNIQUE_PTR +# define CATCH_AUTO_PTR( T ) std::unique_ptr +#else +# define CATCH_AUTO_PTR( T ) std::auto_ptr +#endif + +#define INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) name##line +#define INTERNAL_CATCH_UNIQUE_NAME_LINE( name, line ) INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) +#ifdef CATCH_CONFIG_COUNTER +# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __COUNTER__ ) +#else +# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __LINE__ ) +#endif + +#define INTERNAL_CATCH_STRINGIFY2( expr ) #expr +#define INTERNAL_CATCH_STRINGIFY( expr ) INTERNAL_CATCH_STRINGIFY2( expr ) + +#include +#include + +namespace Catch { + + struct IConfig; + + struct CaseSensitive { enum Choice { + Yes, + No + }; }; + + class NonCopyable { +#ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + NonCopyable( NonCopyable const& ) = delete; + NonCopyable( NonCopyable && ) = delete; + NonCopyable& operator = ( NonCopyable const& ) = delete; + NonCopyable& operator = ( NonCopyable && ) = delete; +#else + NonCopyable( NonCopyable const& info ); + NonCopyable& operator = ( NonCopyable const& ); +#endif + + protected: + NonCopyable() {} + virtual ~NonCopyable(); + }; + + class SafeBool { + public: + typedef void (SafeBool::*type)() const; + + static type makeSafe( bool value ) { + return value ? &SafeBool::trueValue : 0; + } + private: + void trueValue() const {} + }; + + template + inline void deleteAll( ContainerT& container ) { + typename ContainerT::const_iterator it = container.begin(); + typename ContainerT::const_iterator itEnd = container.end(); + for(; it != itEnd; ++it ) + delete *it; + } + template + inline void deleteAllValues( AssociativeContainerT& container ) { + typename AssociativeContainerT::const_iterator it = container.begin(); + typename AssociativeContainerT::const_iterator itEnd = container.end(); + for(; it != itEnd; ++it ) + delete it->second; + } + + bool startsWith( std::string const& s, std::string const& prefix ); + bool startsWith( std::string const& s, char prefix ); + bool endsWith( std::string const& s, std::string const& suffix ); + bool endsWith( std::string const& s, char suffix ); + bool contains( std::string const& s, std::string const& infix ); + void toLowerInPlace( std::string& s ); + std::string toLower( std::string const& s ); + std::string trim( std::string const& str ); + bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ); + + struct pluralise { + pluralise( std::size_t count, std::string const& label ); + + friend std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ); + + std::size_t m_count; + std::string m_label; + }; + + struct SourceLineInfo { + + SourceLineInfo(); + SourceLineInfo( char const* _file, std::size_t _line ); +# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + SourceLineInfo(SourceLineInfo const& other) = default; + SourceLineInfo( SourceLineInfo && ) = default; + SourceLineInfo& operator = ( SourceLineInfo const& ) = default; + SourceLineInfo& operator = ( SourceLineInfo && ) = default; +# endif + bool empty() const; + bool operator == ( SourceLineInfo const& other ) const; + bool operator < ( SourceLineInfo const& other ) const; + + char const* file; + std::size_t line; + }; + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ); + + // This is just here to avoid compiler warnings with macro constants and boolean literals + inline bool isTrue( bool value ){ return value; } + inline bool alwaysTrue() { return true; } + inline bool alwaysFalse() { return false; } + + void throwLogicError( std::string const& message, SourceLineInfo const& locationInfo ); + + void seedRng( IConfig const& config ); + unsigned int rngSeed(); + + // Use this in variadic streaming macros to allow + // >> +StreamEndStop + // as well as + // >> stuff +StreamEndStop + struct StreamEndStop { + std::string operator+() { + return std::string(); + } + }; + template + T const& operator + ( T const& value, StreamEndStop ) { + return value; + } +} + +#define CATCH_INTERNAL_LINEINFO ::Catch::SourceLineInfo( __FILE__, static_cast( __LINE__ ) ) +#define CATCH_INTERNAL_ERROR( msg ) ::Catch::throwLogicError( msg, CATCH_INTERNAL_LINEINFO ); + +namespace Catch { + + class NotImplementedException : public std::exception + { + public: + NotImplementedException( SourceLineInfo const& lineInfo ); + NotImplementedException( NotImplementedException const& ) {} + + virtual ~NotImplementedException() CATCH_NOEXCEPT {} + + virtual const char* what() const CATCH_NOEXCEPT; + + private: + std::string m_what; + SourceLineInfo m_lineInfo; + }; + +} // end namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define CATCH_NOT_IMPLEMENTED throw Catch::NotImplementedException( CATCH_INTERNAL_LINEINFO ) + +// #included from: internal/catch_context.h +#define TWOBLUECUBES_CATCH_CONTEXT_H_INCLUDED + +// #included from: catch_interfaces_generators.h +#define TWOBLUECUBES_CATCH_INTERFACES_GENERATORS_H_INCLUDED + +#include + +namespace Catch { + + struct IGeneratorInfo { + virtual ~IGeneratorInfo(); + virtual bool moveNext() = 0; + virtual std::size_t getCurrentIndex() const = 0; + }; + + struct IGeneratorsForTest { + virtual ~IGeneratorsForTest(); + + virtual IGeneratorInfo& getGeneratorInfo( std::string const& fileInfo, std::size_t size ) = 0; + virtual bool moveNext() = 0; + }; + + IGeneratorsForTest* createGeneratorsForTest(); + +} // end namespace Catch + +// #included from: catch_ptr.hpp +#define TWOBLUECUBES_CATCH_PTR_HPP_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + // An intrusive reference counting smart pointer. + // T must implement addRef() and release() methods + // typically implementing the IShared interface + template + class Ptr { + public: + Ptr() : m_p( CATCH_NULL ){} + Ptr( T* p ) : m_p( p ){ + if( m_p ) + m_p->addRef(); + } + Ptr( Ptr const& other ) : m_p( other.m_p ){ + if( m_p ) + m_p->addRef(); + } + ~Ptr(){ + if( m_p ) + m_p->release(); + } + void reset() { + if( m_p ) + m_p->release(); + m_p = CATCH_NULL; + } + Ptr& operator = ( T* p ){ + Ptr temp( p ); + swap( temp ); + return *this; + } + Ptr& operator = ( Ptr const& other ){ + Ptr temp( other ); + swap( temp ); + return *this; + } + void swap( Ptr& other ) { std::swap( m_p, other.m_p ); } + T* get() const{ return m_p; } + T& operator*() const { return *m_p; } + T* operator->() const { return m_p; } + bool operator !() const { return m_p == CATCH_NULL; } + operator SafeBool::type() const { return SafeBool::makeSafe( m_p != CATCH_NULL ); } + + private: + T* m_p; + }; + + struct IShared : NonCopyable { + virtual ~IShared(); + virtual void addRef() const = 0; + virtual void release() const = 0; + }; + + template + struct SharedImpl : T { + + SharedImpl() : m_rc( 0 ){} + + virtual void addRef() const { + ++m_rc; + } + virtual void release() const { + if( --m_rc == 0 ) + delete this; + } + + mutable unsigned int m_rc; + }; + +} // end namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +namespace Catch { + + class TestCase; + class Stream; + struct IResultCapture; + struct IRunner; + struct IGeneratorsForTest; + struct IConfig; + + struct IContext + { + virtual ~IContext(); + + virtual IResultCapture* getResultCapture() = 0; + virtual IRunner* getRunner() = 0; + virtual size_t getGeneratorIndex( std::string const& fileInfo, size_t totalSize ) = 0; + virtual bool advanceGeneratorsForCurrentTest() = 0; + virtual Ptr getConfig() const = 0; + }; + + struct IMutableContext : IContext + { + virtual ~IMutableContext(); + virtual void setResultCapture( IResultCapture* resultCapture ) = 0; + virtual void setRunner( IRunner* runner ) = 0; + virtual void setConfig( Ptr const& config ) = 0; + }; + + IContext& getCurrentContext(); + IMutableContext& getCurrentMutableContext(); + void cleanUpContext(); + Stream createStream( std::string const& streamName ); + +} + +// #included from: internal/catch_test_registry.hpp +#define TWOBLUECUBES_CATCH_TEST_REGISTRY_HPP_INCLUDED + +// #included from: catch_interfaces_testcase.h +#define TWOBLUECUBES_CATCH_INTERFACES_TESTCASE_H_INCLUDED + +#include + +namespace Catch { + + class TestSpec; + + struct ITestCase : IShared { + virtual void invoke () const = 0; + protected: + virtual ~ITestCase(); + }; + + class TestCase; + struct IConfig; + + struct ITestCaseRegistry { + virtual ~ITestCaseRegistry(); + virtual std::vector const& getAllTests() const = 0; + virtual std::vector const& getAllTestsSorted( IConfig const& config ) const = 0; + }; + + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ); + std::vector filterTests( std::vector const& testCases, TestSpec const& testSpec, IConfig const& config ); + std::vector const& getAllTestCasesSorted( IConfig const& config ); + +} + +namespace Catch { + +template +class MethodTestCase : public SharedImpl { + +public: + MethodTestCase( void (C::*method)() ) : m_method( method ) {} + + virtual void invoke() const { + C obj; + (obj.*m_method)(); + } + +private: + virtual ~MethodTestCase() {} + + void (C::*m_method)(); +}; + +typedef void(*TestFunction)(); + +struct NameAndDesc { + NameAndDesc( const char* _name = "", const char* _description= "" ) + : name( _name ), description( _description ) + {} + + const char* name; + const char* description; +}; + +void registerTestCase + ( ITestCase* testCase, + char const* className, + NameAndDesc const& nameAndDesc, + SourceLineInfo const& lineInfo ); + +struct AutoReg { + + AutoReg + ( TestFunction function, + SourceLineInfo const& lineInfo, + NameAndDesc const& nameAndDesc ); + + template + AutoReg + ( void (C::*method)(), + char const* className, + NameAndDesc const& nameAndDesc, + SourceLineInfo const& lineInfo ) { + + registerTestCase + ( new MethodTestCase( method ), + className, + nameAndDesc, + lineInfo ); + } + + ~AutoReg(); + +private: + AutoReg( AutoReg const& ); + void operator= ( AutoReg const& ); +}; + +void registerTestCaseFunction + ( TestFunction function, + SourceLineInfo const& lineInfo, + NameAndDesc const& nameAndDesc ); + +} // end namespace Catch + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TESTCASE2( TestName, ... ) \ + static void TestName(); \ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &TestName, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( __VA_ARGS__ ) ); } \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \ + static void TestName() + #define INTERNAL_CATCH_TESTCASE( ... ) \ + INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), __VA_ARGS__ ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, ... ) \ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &QualifiedMethod, "&" #QualifiedMethod, Catch::NameAndDesc( __VA_ARGS__ ), CATCH_INTERNAL_LINEINFO ); } \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEST_CASE_METHOD2( TestName, ClassName, ... )\ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + namespace{ \ + struct TestName : ClassName{ \ + void test(); \ + }; \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( &TestName::test, #ClassName, Catch::NameAndDesc( __VA_ARGS__ ), CATCH_INTERNAL_LINEINFO ); \ + } \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \ + void TestName::test() + #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, ... ) \ + INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, __VA_ARGS__ ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_REGISTER_TESTCASE( Function, ... ) \ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + Catch::AutoReg( Function, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( __VA_ARGS__ ) ); \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS + +#else + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TESTCASE2( TestName, Name, Desc ) \ + static void TestName(); \ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &TestName, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( Name, Desc ) ); }\ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \ + static void TestName() + #define INTERNAL_CATCH_TESTCASE( Name, Desc ) \ + INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), Name, Desc ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, Name, Desc ) \ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &QualifiedMethod, "&" #QualifiedMethod, Catch::NameAndDesc( Name, Desc ), CATCH_INTERNAL_LINEINFO ); } \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEST_CASE_METHOD2( TestCaseName, ClassName, TestName, Desc )\ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + namespace{ \ + struct TestCaseName : ClassName{ \ + void test(); \ + }; \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( &TestCaseName::test, #ClassName, Catch::NameAndDesc( TestName, Desc ), CATCH_INTERNAL_LINEINFO ); \ + } \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \ + void TestCaseName::test() + #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, TestName, Desc )\ + INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, TestName, Desc ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_REGISTER_TESTCASE( Function, Name, Desc ) \ + CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \ + Catch::AutoReg( Function, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( Name, Desc ) ); \ + CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS + +#endif + +// #included from: internal/catch_capture.hpp +#define TWOBLUECUBES_CATCH_CAPTURE_HPP_INCLUDED + +// #included from: catch_result_builder.h +#define TWOBLUECUBES_CATCH_RESULT_BUILDER_H_INCLUDED + +// #included from: catch_result_type.h +#define TWOBLUECUBES_CATCH_RESULT_TYPE_H_INCLUDED + +namespace Catch { + + // ResultWas::OfType enum + struct ResultWas { enum OfType { + Unknown = -1, + Ok = 0, + Info = 1, + Warning = 2, + + FailureBit = 0x10, + + ExpressionFailed = FailureBit | 1, + ExplicitFailure = FailureBit | 2, + + Exception = 0x100 | FailureBit, + + ThrewException = Exception | 1, + DidntThrowException = Exception | 2, + + FatalErrorCondition = 0x200 | FailureBit + + }; }; + + inline bool isOk( ResultWas::OfType resultType ) { + return ( resultType & ResultWas::FailureBit ) == 0; + } + inline bool isJustInfo( int flags ) { + return flags == ResultWas::Info; + } + + // ResultDisposition::Flags enum + struct ResultDisposition { enum Flags { + Normal = 0x01, + + ContinueOnFailure = 0x02, // Failures fail test, but execution continues + FalseTest = 0x04, // Prefix expression with ! + SuppressFail = 0x08 // Failures are reported but do not fail the test + }; }; + + inline ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ) { + return static_cast( static_cast( lhs ) | static_cast( rhs ) ); + } + + inline bool shouldContinueOnFailure( int flags ) { return ( flags & ResultDisposition::ContinueOnFailure ) != 0; } + inline bool isFalseTest( int flags ) { return ( flags & ResultDisposition::FalseTest ) != 0; } + inline bool shouldSuppressFailure( int flags ) { return ( flags & ResultDisposition::SuppressFail ) != 0; } + +} // end namespace Catch + +// #included from: catch_assertionresult.h +#define TWOBLUECUBES_CATCH_ASSERTIONRESULT_H_INCLUDED + +#include + +namespace Catch { + + struct STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison; + + struct DecomposedExpression + { + virtual ~DecomposedExpression() {} + virtual bool isBinaryExpression() const { + return false; + } + virtual void reconstructExpression( std::string& dest ) const = 0; + + // Only simple binary comparisons can be decomposed. + // If more complex check is required then wrap sub-expressions in parentheses. + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator + ( T const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator - ( T const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator * ( T const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator / ( T const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator % ( T const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator && ( T const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator || ( T const& ); + + private: + DecomposedExpression& operator = (DecomposedExpression const&); + }; + + struct AssertionInfo + { + AssertionInfo() {} + AssertionInfo( char const * _macroName, + SourceLineInfo const& _lineInfo, + char const * _capturedExpression, + ResultDisposition::Flags _resultDisposition, + char const * _secondArg = ""); + + char const * macroName; + SourceLineInfo lineInfo; + char const * capturedExpression; + ResultDisposition::Flags resultDisposition; + char const * secondArg; + }; + + struct AssertionResultData + { + AssertionResultData() : decomposedExpression( CATCH_NULL ) + , resultType( ResultWas::Unknown ) + , negated( false ) + , parenthesized( false ) {} + + void negate( bool parenthesize ) { + negated = !negated; + parenthesized = parenthesize; + if( resultType == ResultWas::Ok ) + resultType = ResultWas::ExpressionFailed; + else if( resultType == ResultWas::ExpressionFailed ) + resultType = ResultWas::Ok; + } + + std::string const& reconstructExpression() const { + if( decomposedExpression != CATCH_NULL ) { + decomposedExpression->reconstructExpression( reconstructedExpression ); + if( parenthesized ) { + reconstructedExpression.insert( 0, 1, '(' ); + reconstructedExpression.append( 1, ')' ); + } + if( negated ) { + reconstructedExpression.insert( 0, 1, '!' ); + } + decomposedExpression = CATCH_NULL; + } + return reconstructedExpression; + } + + mutable DecomposedExpression const* decomposedExpression; + mutable std::string reconstructedExpression; + std::string message; + ResultWas::OfType resultType; + bool negated; + bool parenthesized; + }; + + class AssertionResult { + public: + AssertionResult(); + AssertionResult( AssertionInfo const& info, AssertionResultData const& data ); + ~AssertionResult(); +# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + AssertionResult( AssertionResult const& ) = default; + AssertionResult( AssertionResult && ) = default; + AssertionResult& operator = ( AssertionResult const& ) = default; + AssertionResult& operator = ( AssertionResult && ) = default; +# endif + + bool isOk() const; + bool succeeded() const; + ResultWas::OfType getResultType() const; + bool hasExpression() const; + bool hasMessage() const; + std::string getExpression() const; + std::string getExpressionInMacro() const; + bool hasExpandedExpression() const; + std::string getExpandedExpression() const; + std::string getMessage() const; + SourceLineInfo getSourceInfo() const; + std::string getTestMacroName() const; + void discardDecomposedExpression() const; + void expandDecomposedExpression() const; + + protected: + AssertionInfo m_info; + AssertionResultData m_resultData; + }; + +} // end namespace Catch + +// #included from: catch_matchers.hpp +#define TWOBLUECUBES_CATCH_MATCHERS_HPP_INCLUDED + +namespace Catch { +namespace Matchers { + namespace Impl { + + template struct MatchAllOf; + template struct MatchAnyOf; + template struct MatchNotOf; + + class MatcherUntypedBase { + public: + std::string toString() const { + if( m_cachedToString.empty() ) + m_cachedToString = describe(); + return m_cachedToString; + } + + protected: + virtual ~MatcherUntypedBase(); + virtual std::string describe() const = 0; + mutable std::string m_cachedToString; + private: + MatcherUntypedBase& operator = ( MatcherUntypedBase const& ); + }; + + template + struct MatcherMethod { + virtual bool match( ObjectT const& arg ) const = 0; + }; + template + struct MatcherMethod { + virtual bool match( PtrT* arg ) const = 0; + }; + + template + struct MatcherBase : MatcherUntypedBase, MatcherMethod { + + MatchAllOf operator && ( MatcherBase const& other ) const; + MatchAnyOf operator || ( MatcherBase const& other ) const; + MatchNotOf operator ! () const; + }; + + template + struct MatchAllOf : MatcherBase { + virtual bool match( ArgT const& arg ) const CATCH_OVERRIDE { + for( std::size_t i = 0; i < m_matchers.size(); ++i ) { + if (!m_matchers[i]->match(arg)) + return false; + } + return true; + } + virtual std::string describe() const CATCH_OVERRIDE { + std::string description; + description.reserve( 4 + m_matchers.size()*32 ); + description += "( "; + for( std::size_t i = 0; i < m_matchers.size(); ++i ) { + if( i != 0 ) + description += " and "; + description += m_matchers[i]->toString(); + } + description += " )"; + return description; + } + + MatchAllOf& operator && ( MatcherBase const& other ) { + m_matchers.push_back( &other ); + return *this; + } + + std::vector const*> m_matchers; + }; + template + struct MatchAnyOf : MatcherBase { + + virtual bool match( ArgT const& arg ) const CATCH_OVERRIDE { + for( std::size_t i = 0; i < m_matchers.size(); ++i ) { + if (m_matchers[i]->match(arg)) + return true; + } + return false; + } + virtual std::string describe() const CATCH_OVERRIDE { + std::string description; + description.reserve( 4 + m_matchers.size()*32 ); + description += "( "; + for( std::size_t i = 0; i < m_matchers.size(); ++i ) { + if( i != 0 ) + description += " or "; + description += m_matchers[i]->toString(); + } + description += " )"; + return description; + } + + MatchAnyOf& operator || ( MatcherBase const& other ) { + m_matchers.push_back( &other ); + return *this; + } + + std::vector const*> m_matchers; + }; + + template + struct MatchNotOf : MatcherBase { + + MatchNotOf( MatcherBase const& underlyingMatcher ) : m_underlyingMatcher( underlyingMatcher ) {} + + virtual bool match( ArgT const& arg ) const CATCH_OVERRIDE { + return !m_underlyingMatcher.match( arg ); + } + + virtual std::string describe() const CATCH_OVERRIDE { + return "not " + m_underlyingMatcher.toString(); + } + MatcherBase const& m_underlyingMatcher; + }; + + template + MatchAllOf MatcherBase::operator && ( MatcherBase const& other ) const { + return MatchAllOf() && *this && other; + } + template + MatchAnyOf MatcherBase::operator || ( MatcherBase const& other ) const { + return MatchAnyOf() || *this || other; + } + template + MatchNotOf MatcherBase::operator ! () const { + return MatchNotOf( *this ); + } + + } // namespace Impl + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + // - deprecated: prefer ||, && and ! + template + inline Impl::MatchNotOf Not( Impl::MatcherBase const& underlyingMatcher ) { + return Impl::MatchNotOf( underlyingMatcher ); + } + template + inline Impl::MatchAllOf AllOf( Impl::MatcherBase const& m1, Impl::MatcherBase const& m2 ) { + return Impl::MatchAllOf() && m1 && m2; + } + template + inline Impl::MatchAllOf AllOf( Impl::MatcherBase const& m1, Impl::MatcherBase const& m2, Impl::MatcherBase const& m3 ) { + return Impl::MatchAllOf() && m1 && m2 && m3; + } + template + inline Impl::MatchAnyOf AnyOf( Impl::MatcherBase const& m1, Impl::MatcherBase const& m2 ) { + return Impl::MatchAnyOf() || m1 || m2; + } + template + inline Impl::MatchAnyOf AnyOf( Impl::MatcherBase const& m1, Impl::MatcherBase const& m2, Impl::MatcherBase const& m3 ) { + return Impl::MatchAnyOf() || m1 || m2 || m3; + } + +} // namespace Matchers + +using namespace Matchers; +using Matchers::Impl::MatcherBase; + +} // namespace Catch + +namespace Catch { + + struct TestFailureException{}; + + template class ExpressionLhs; + + struct CopyableStream { + CopyableStream() {} + CopyableStream( CopyableStream const& other ) { + oss << other.oss.str(); + } + CopyableStream& operator=( CopyableStream const& other ) { + oss.str(std::string()); + oss << other.oss.str(); + return *this; + } + std::ostringstream oss; + }; + + class ResultBuilder : public DecomposedExpression { + public: + ResultBuilder( char const* macroName, + SourceLineInfo const& lineInfo, + char const* capturedExpression, + ResultDisposition::Flags resultDisposition, + char const* secondArg = "" ); + ~ResultBuilder(); + + template + ExpressionLhs operator <= ( T const& operand ); + ExpressionLhs operator <= ( bool value ); + + template + ResultBuilder& operator << ( T const& value ) { + m_stream().oss << value; + return *this; + } + + ResultBuilder& setResultType( ResultWas::OfType result ); + ResultBuilder& setResultType( bool result ); + + void endExpression( DecomposedExpression const& expr ); + + virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE; + + AssertionResult build() const; + AssertionResult build( DecomposedExpression const& expr ) const; + + void useActiveException( ResultDisposition::Flags resultDisposition = ResultDisposition::Normal ); + void captureResult( ResultWas::OfType resultType ); + void captureExpression(); + void captureExpectedException( std::string const& expectedMessage ); + void captureExpectedException( Matchers::Impl::MatcherBase const& matcher ); + void handleResult( AssertionResult const& result ); + void react(); + bool shouldDebugBreak() const; + bool allowThrows() const; + + template + void captureMatch( ArgT const& arg, MatcherT const& matcher, char const* matcherString ); + + void setExceptionGuard(); + void unsetExceptionGuard(); + + private: + AssertionInfo m_assertionInfo; + AssertionResultData m_data; + + static CopyableStream &m_stream() + { + static CopyableStream s; + return s; + } + + bool m_shouldDebugBreak; + bool m_shouldThrow; + bool m_guardException; + }; + +} // namespace Catch + +// Include after due to circular dependency: +// #included from: catch_expression_lhs.hpp +#define TWOBLUECUBES_CATCH_EXPRESSION_LHS_HPP_INCLUDED + +// #included from: catch_evaluate.hpp +#define TWOBLUECUBES_CATCH_EVALUATE_HPP_INCLUDED + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable:4389) // '==' : signed/unsigned mismatch +#pragma warning(disable:4312) // Converting int to T* using reinterpret_cast (issue on x64 platform) +#endif + +#include + +namespace Catch { +namespace Internal { + + enum Operator { + IsEqualTo, + IsNotEqualTo, + IsLessThan, + IsGreaterThan, + IsLessThanOrEqualTo, + IsGreaterThanOrEqualTo + }; + + template struct OperatorTraits { static const char* getName(){ return "*error*"; } }; + template<> struct OperatorTraits { static const char* getName(){ return "=="; } }; + template<> struct OperatorTraits { static const char* getName(){ return "!="; } }; + template<> struct OperatorTraits { static const char* getName(){ return "<"; } }; + template<> struct OperatorTraits { static const char* getName(){ return ">"; } }; + template<> struct OperatorTraits { static const char* getName(){ return "<="; } }; + template<> struct OperatorTraits{ static const char* getName(){ return ">="; } }; + + template + inline T& opCast(T const& t) { return const_cast(t); } + +// nullptr_t support based on pull request #154 from Konstantin Baumann +#ifdef CATCH_CONFIG_CPP11_NULLPTR + inline std::nullptr_t opCast(std::nullptr_t) { return nullptr; } +#endif // CATCH_CONFIG_CPP11_NULLPTR + + // So the compare overloads can be operator agnostic we convey the operator as a template + // enum, which is used to specialise an Evaluator for doing the comparison. + template + class Evaluator{}; + + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs) { + return bool( opCast( lhs ) == opCast( rhs ) ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return bool( opCast( lhs ) != opCast( rhs ) ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return bool( opCast( lhs ) < opCast( rhs ) ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return bool( opCast( lhs ) > opCast( rhs ) ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return bool( opCast( lhs ) >= opCast( rhs ) ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return bool( opCast( lhs ) <= opCast( rhs ) ); + } + }; + + template + bool applyEvaluator( T1 const& lhs, T2 const& rhs ) { + return Evaluator::evaluate( lhs, rhs ); + } + + // This level of indirection allows us to specialise for integer types + // to avoid signed/ unsigned warnings + + // "base" overload + template + bool compare( T1 const& lhs, T2 const& rhs ) { + return Evaluator::evaluate( lhs, rhs ); + } + + // unsigned X to int + template bool compare( unsigned int lhs, int rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned long lhs, int rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned char lhs, int rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + + // unsigned X to long + template bool compare( unsigned int lhs, long rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned long lhs, long rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned char lhs, long rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + + // int to unsigned X + template bool compare( int lhs, unsigned int rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( int lhs, unsigned long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( int lhs, unsigned char rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + + // long to unsigned X + template bool compare( long lhs, unsigned int rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long lhs, unsigned long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long lhs, unsigned char rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + + // pointer to long (when comparing against NULL) + template bool compare( long lhs, T* rhs ) { + return Evaluator::evaluate( reinterpret_cast( lhs ), rhs ); + } + template bool compare( T* lhs, long rhs ) { + return Evaluator::evaluate( lhs, reinterpret_cast( rhs ) ); + } + + // pointer to int (when comparing against NULL) + template bool compare( int lhs, T* rhs ) { + return Evaluator::evaluate( reinterpret_cast( lhs ), rhs ); + } + template bool compare( T* lhs, int rhs ) { + return Evaluator::evaluate( lhs, reinterpret_cast( rhs ) ); + } + +#ifdef CATCH_CONFIG_CPP11_LONG_LONG + // long long to unsigned X + template bool compare( long long lhs, unsigned int rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long long lhs, unsigned long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long long lhs, unsigned long long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long long lhs, unsigned char rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + + // unsigned long long to X + template bool compare( unsigned long long lhs, int rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( unsigned long long lhs, long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( unsigned long long lhs, long long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( unsigned long long lhs, char rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + + // pointer to long long (when comparing against NULL) + template bool compare( long long lhs, T* rhs ) { + return Evaluator::evaluate( reinterpret_cast( lhs ), rhs ); + } + template bool compare( T* lhs, long long rhs ) { + return Evaluator::evaluate( lhs, reinterpret_cast( rhs ) ); + } +#endif // CATCH_CONFIG_CPP11_LONG_LONG + +#ifdef CATCH_CONFIG_CPP11_NULLPTR + // pointer to nullptr_t (when comparing against nullptr) + template bool compare( std::nullptr_t, T* rhs ) { + return Evaluator::evaluate( nullptr, rhs ); + } + template bool compare( T* lhs, std::nullptr_t ) { + return Evaluator::evaluate( lhs, nullptr ); + } +#endif // CATCH_CONFIG_CPP11_NULLPTR + +} // end of namespace Internal +} // end of namespace Catch + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +// #included from: catch_tostring.h +#define TWOBLUECUBES_CATCH_TOSTRING_H_INCLUDED + +#include +#include +#include +#include +#include + +#ifdef __OBJC__ +// #included from: catch_objc_arc.hpp +#define TWOBLUECUBES_CATCH_OBJC_ARC_HPP_INCLUDED + +#import + +#ifdef __has_feature +#define CATCH_ARC_ENABLED __has_feature(objc_arc) +#else +#define CATCH_ARC_ENABLED 0 +#endif + +void arcSafeRelease( NSObject* obj ); +id performOptionalSelector( id obj, SEL sel ); + +#if !CATCH_ARC_ENABLED +inline void arcSafeRelease( NSObject* obj ) { + [obj release]; +} +inline id performOptionalSelector( id obj, SEL sel ) { + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; + return nil; +} +#define CATCH_UNSAFE_UNRETAINED +#define CATCH_ARC_STRONG +#else +inline void arcSafeRelease( NSObject* ){} +inline id performOptionalSelector( id obj, SEL sel ) { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Warc-performSelector-leaks" +#endif + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + return nil; +} +#define CATCH_UNSAFE_UNRETAINED __unsafe_unretained +#define CATCH_ARC_STRONG __strong +#endif + +#endif + +#ifdef CATCH_CONFIG_CPP11_TUPLE +#include +#endif + +#ifdef CATCH_CONFIG_CPP11_IS_ENUM +#include +#endif + +namespace Catch { + +// Why we're here. +template +std::string toString( T const& value ); + +// Built in overloads + +std::string toString( std::string const& value ); +std::string toString( std::wstring const& value ); +std::string toString( const char* const value ); +std::string toString( char* const value ); +std::string toString( const wchar_t* const value ); +std::string toString( wchar_t* const value ); +std::string toString( int value ); +std::string toString( unsigned long value ); +std::string toString( unsigned int value ); +std::string toString( const double value ); +std::string toString( const float value ); +std::string toString( bool value ); +std::string toString( char value ); +std::string toString( signed char value ); +std::string toString( unsigned char value ); + +#ifdef CATCH_CONFIG_CPP11_LONG_LONG +std::string toString( long long value ); +std::string toString( unsigned long long value ); +#endif + +#ifdef CATCH_CONFIG_CPP11_NULLPTR +std::string toString( std::nullptr_t ); +#endif + +#ifdef __OBJC__ + std::string toString( NSString const * const& nsstring ); + std::string toString( NSString * CATCH_ARC_STRONG & nsstring ); + std::string toString( NSObject* const& nsObject ); +#endif + +namespace Detail { + + extern const std::string unprintableString; + + #if !defined(CATCH_CONFIG_CPP11_STREAM_INSERTABLE_CHECK) + struct BorgType { + template BorgType( T const& ); + }; + + struct TrueType { char sizer[1]; }; + struct FalseType { char sizer[2]; }; + + TrueType& testStreamable( std::ostream& ); + FalseType testStreamable( FalseType ); + + FalseType operator<<( std::ostream const&, BorgType const& ); + + template + struct IsStreamInsertable { + static std::ostream &s; + static T const&t; + enum { value = sizeof( testStreamable(s << t) ) == sizeof( TrueType ) }; + }; +#else + template + class IsStreamInsertable { + template + static auto test(int) + -> decltype( std::declval() << std::declval(), std::true_type() ); + + template + static auto test(...) -> std::false_type; + + public: + static const bool value = decltype(test(0))::value; + }; +#endif + +#if defined(CATCH_CONFIG_CPP11_IS_ENUM) + template::value + > + struct EnumStringMaker + { + static std::string convert( T const& ) { return unprintableString; } + }; + + template + struct EnumStringMaker + { + static std::string convert( T const& v ) + { + return ::Catch::toString( + static_cast::type>(v) + ); + } + }; +#endif + template + struct StringMakerBase { +#if defined(CATCH_CONFIG_CPP11_IS_ENUM) + template + static std::string convert( T const& v ) + { + return EnumStringMaker::convert( v ); + } +#else + template + static std::string convert( T const& ) { return unprintableString; } +#endif + }; + + template<> + struct StringMakerBase { + template + static std::string convert( T const& _value ) { + std::ostringstream oss; + oss << _value; + return oss.str(); + } + }; + + std::string rawMemoryToString( const void *object, std::size_t size ); + + template + inline std::string rawMemoryToString( const T& object ) { + return rawMemoryToString( &object, sizeof(object) ); + } + +} // end namespace Detail + +template +struct StringMaker : + Detail::StringMakerBase::value> {}; + +template +struct StringMaker { + template + static std::string convert( U* p ) { + if( !p ) + return "NULL"; + else + return Detail::rawMemoryToString( p ); + } +}; + +template +struct StringMaker { + static std::string convert( R C::* p ) { + if( !p ) + return "NULL"; + else + return Detail::rawMemoryToString( p ); + } +}; + +namespace Detail { + template + std::string rangeToString( InputIterator first, InputIterator last ); +} + +//template +//struct StringMaker > { +// static std::string convert( std::vector const& v ) { +// return Detail::rangeToString( v.begin(), v.end() ); +// } +//}; + +template +std::string toString( std::vector const& v ) { + return Detail::rangeToString( v.begin(), v.end() ); +} + +#ifdef CATCH_CONFIG_CPP11_TUPLE + +// toString for tuples +namespace TupleDetail { + template< + typename Tuple, + std::size_t N = 0, + bool = (N < std::tuple_size::value) + > + struct ElementPrinter { + static void print( const Tuple& tuple, std::ostream& os ) + { + os << ( N ? ", " : " " ) + << Catch::toString(std::get(tuple)); + ElementPrinter::print(tuple,os); + } + }; + + template< + typename Tuple, + std::size_t N + > + struct ElementPrinter { + static void print( const Tuple&, std::ostream& ) {} + }; + +} + +template +struct StringMaker> { + + static std::string convert( const std::tuple& tuple ) + { + std::ostringstream os; + os << '{'; + TupleDetail::ElementPrinter>::print( tuple, os ); + os << " }"; + return os.str(); + } +}; +#endif // CATCH_CONFIG_CPP11_TUPLE + +namespace Detail { + template + std::string makeString( T const& value ) { + return StringMaker::convert( value ); + } +} // end namespace Detail + +/// \brief converts any type to a string +/// +/// The default template forwards on to ostringstream - except when an +/// ostringstream overload does not exist - in which case it attempts to detect +/// that and writes {?}. +/// Overload (not specialise) this template for custom typs that you don't want +/// to provide an ostream overload for. +template +std::string toString( T const& value ) { + return StringMaker::convert( value ); +} + + namespace Detail { + template + std::string rangeToString( InputIterator first, InputIterator last ) { + std::ostringstream oss; + oss << "{ "; + if( first != last ) { + oss << Catch::toString( *first ); + for( ++first ; first != last ; ++first ) + oss << ", " << Catch::toString( *first ); + } + oss << " }"; + return oss.str(); + } +} + +} // end namespace Catch + +namespace Catch { + +template +class BinaryExpression; + +template +class MatchExpression; + +// Wraps the LHS of an expression and overloads comparison operators +// for also capturing those and RHS (if any) +template +class ExpressionLhs : public DecomposedExpression { +public: + ExpressionLhs( ResultBuilder& rb, T lhs ) : m_rb( rb ), m_lhs( lhs ), m_truthy(false) {} + + ExpressionLhs& operator = ( const ExpressionLhs& ); + + template + BinaryExpression + operator == ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + BinaryExpression + operator != ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + BinaryExpression + operator < ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + BinaryExpression + operator > ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + BinaryExpression + operator <= ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + BinaryExpression + operator >= ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + BinaryExpression operator == ( bool rhs ) { + return captureExpression( rhs ); + } + + BinaryExpression operator != ( bool rhs ) { + return captureExpression( rhs ); + } + + void endExpression() { + m_truthy = m_lhs ? true : false; + m_rb + .setResultType( m_truthy ) + .endExpression( *this ); + } + + virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE { + dest = Catch::toString( m_lhs ); + } + +private: + template + BinaryExpression captureExpression( RhsT& rhs ) const { + return BinaryExpression( m_rb, m_lhs, rhs ); + } + + template + BinaryExpression captureExpression( bool rhs ) const { + return BinaryExpression( m_rb, m_lhs, rhs ); + } + +private: + ResultBuilder& m_rb; + T m_lhs; + bool m_truthy; +}; + +template +class BinaryExpression : public DecomposedExpression { +public: + BinaryExpression( ResultBuilder& rb, LhsT lhs, RhsT rhs ) + : m_rb( rb ), m_lhs( lhs ), m_rhs( rhs ) {} + + BinaryExpression& operator = ( BinaryExpression& ); + + void endExpression() const { + m_rb + .setResultType( Internal::compare( m_lhs, m_rhs ) ) + .endExpression( *this ); + } + + virtual bool isBinaryExpression() const CATCH_OVERRIDE { + return true; + } + + virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE { + std::string lhs = Catch::toString( m_lhs ); + std::string rhs = Catch::toString( m_rhs ); + char delim = lhs.size() + rhs.size() < 40 && + lhs.find('\n') == std::string::npos && + rhs.find('\n') == std::string::npos ? ' ' : '\n'; + dest.reserve( 7 + lhs.size() + rhs.size() ); + // 2 for spaces around operator + // 2 for operator + // 2 for parentheses (conditionally added later) + // 1 for negation (conditionally added later) + dest = lhs; + dest += delim; + dest += Internal::OperatorTraits::getName(); + dest += delim; + dest += rhs; + } + +private: + ResultBuilder& m_rb; + LhsT m_lhs; + RhsT m_rhs; +}; + +template +class MatchExpression : public DecomposedExpression { +public: + MatchExpression( ArgT arg, MatcherT matcher, char const* matcherString ) + : m_arg( arg ), m_matcher( matcher ), m_matcherString( matcherString ) {} + + virtual bool isBinaryExpression() const CATCH_OVERRIDE { + return true; + } + + virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE { + std::string matcherAsString = m_matcher.toString(); + dest = Catch::toString( m_arg ); + dest += ' '; + if( matcherAsString == Detail::unprintableString ) + dest += m_matcherString; + else + dest += matcherAsString; + } + +private: + ArgT m_arg; + MatcherT m_matcher; + char const* m_matcherString; +}; + +} // end namespace Catch + + +namespace Catch { + + template + inline ExpressionLhs ResultBuilder::operator <= ( T const& operand ) { + return ExpressionLhs( *this, operand ); + } + + inline ExpressionLhs ResultBuilder::operator <= ( bool value ) { + return ExpressionLhs( *this, value ); + } + + template + inline void ResultBuilder::captureMatch( ArgT const& arg, MatcherT const& matcher, + char const* matcherString ) { + MatchExpression expr( arg, matcher, matcherString ); + setResultType( matcher.match( arg ) ); + endExpression( expr ); + } + +} // namespace Catch + +// #included from: catch_message.h +#define TWOBLUECUBES_CATCH_MESSAGE_H_INCLUDED + +#include + +namespace Catch { + + struct MessageInfo { + MessageInfo( std::string const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ); + + std::string macroName; + SourceLineInfo lineInfo; + ResultWas::OfType type; + std::string message; + unsigned int sequence; + + bool operator == ( MessageInfo const& other ) const { + return sequence == other.sequence; + } + bool operator < ( MessageInfo const& other ) const { + return sequence < other.sequence; + } + private: + static unsigned int globalCount; + }; + + struct MessageBuilder { + MessageBuilder( std::string const& macroName, + SourceLineInfo const& lineInfo, + ResultWas::OfType type ) + : m_info( macroName, lineInfo, type ) + {} + + template + MessageBuilder& operator << ( T const& value ) { + m_stream << value; + return *this; + } + + MessageInfo m_info; + std::ostringstream m_stream; + }; + + class ScopedMessage { + public: + ScopedMessage( MessageBuilder const& builder ); + ScopedMessage( ScopedMessage const& other ); + ~ScopedMessage(); + + MessageInfo m_info; + }; + +} // end namespace Catch + +// #included from: catch_interfaces_capture.h +#define TWOBLUECUBES_CATCH_INTERFACES_CAPTURE_H_INCLUDED + +#include + +namespace Catch { + + class TestCase; + class AssertionResult; + struct AssertionInfo; + struct SectionInfo; + struct SectionEndInfo; + struct MessageInfo; + class ScopedMessageBuilder; + struct Counts; + + struct IResultCapture { + + virtual ~IResultCapture(); + + virtual void assertionEnded( AssertionResult const& result ) = 0; + virtual bool sectionStarted( SectionInfo const& sectionInfo, + Counts& assertions ) = 0; + virtual void sectionEnded( SectionEndInfo const& endInfo ) = 0; + virtual void sectionEndedEarly( SectionEndInfo const& endInfo ) = 0; + virtual void pushScopedMessage( MessageInfo const& message ) = 0; + virtual void popScopedMessage( MessageInfo const& message ) = 0; + + virtual std::string getCurrentTestName() const = 0; + virtual const AssertionResult* getLastResult() const = 0; + + virtual void exceptionEarlyReported() = 0; + + virtual void handleFatalErrorCondition( std::string const& message ) = 0; + }; + + IResultCapture& getResultCapture(); +} + +// #included from: catch_debugger.h +#define TWOBLUECUBES_CATCH_DEBUGGER_H_INCLUDED + +// #included from: catch_platform.h +#define TWOBLUECUBES_CATCH_PLATFORM_H_INCLUDED + +#if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) +# define CATCH_PLATFORM_MAC +#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) +# define CATCH_PLATFORM_IPHONE +#elif defined(linux) || defined(__linux) || defined(__linux__) +# define CATCH_PLATFORM_LINUX +#elif defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) +# define CATCH_PLATFORM_WINDOWS +# if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX) +# define CATCH_DEFINES_NOMINMAX +# endif +# if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN) +# define CATCH_DEFINES_WIN32_LEAN_AND_MEAN +# endif +#endif + +#include + +namespace Catch{ + + bool isDebuggerActive(); + void writeToDebugConsole( std::string const& text ); +} + +#ifdef CATCH_PLATFORM_MAC + + // The following code snippet based on: + // http://cocoawithlove.com/2008/03/break-into-debugger.html + #if defined(__ppc64__) || defined(__ppc__) + #define CATCH_TRAP() \ + __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n" \ + : : : "memory","r0","r3","r4" ) + #else + #define CATCH_TRAP() __asm__("int $3\n" : : ) + #endif + +#elif defined(CATCH_PLATFORM_LINUX) + // If we can use inline assembler, do it because this allows us to break + // directly at the location of the failing check instead of breaking inside + // raise() called from it, i.e. one stack frame below. + #if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) + #define CATCH_TRAP() asm volatile ("int $3") + #else // Fall back to the generic way. + #include + + #define CATCH_TRAP() raise(SIGTRAP) + #endif +#elif defined(_MSC_VER) + #define CATCH_TRAP() __debugbreak() +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) void __stdcall DebugBreak(); + #define CATCH_TRAP() DebugBreak() +#endif + +#ifdef CATCH_TRAP + #define CATCH_BREAK_INTO_DEBUGGER() if( Catch::isDebuggerActive() ) { CATCH_TRAP(); } +#else + #define CATCH_BREAK_INTO_DEBUGGER() Catch::alwaysTrue(); +#endif + +// #included from: catch_interfaces_runner.h +#define TWOBLUECUBES_CATCH_INTERFACES_RUNNER_H_INCLUDED + +namespace Catch { + class TestCase; + + struct IRunner { + virtual ~IRunner(); + virtual bool aborting() const = 0; + }; +} + +#if defined(CATCH_CONFIG_FAST_COMPILE) +/////////////////////////////////////////////////////////////////////////////// +// We can speedup compilation significantly by breaking into debugger lower in +// the callstack, because then we don't have to expand CATCH_BREAK_INTO_DEBUGGER +// macro in each assertion +#define INTERNAL_CATCH_REACT( resultBuilder ) \ + resultBuilder.react(); + +/////////////////////////////////////////////////////////////////////////////// +// Another way to speed-up compilation is to omit local try-catch for REQUIRE* +// macros. +// This can potentially cause false negative, if the test code catches +// the exception before it propagates back up to the runner. +#define INTERNAL_CATCH_TEST_NO_TRY( macroName, resultDisposition, expr ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + __catchResult.setExceptionGuard(); \ + CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + ( __catchResult <= expr ).endExpression(); \ + CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + __catchResult.unsetExceptionGuard(); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::isTrue( false && static_cast( !!(expr) ) ) ) // expr here is never evaluated at runtime but it forces the compiler to give it a look +// The double negation silences MSVC's C4800 warning, the static_cast forces short-circuit evaluation if the type has overloaded &&. + +#define INTERNAL_CHECK_THAT_NO_TRY( macroName, matcher, resultDisposition, arg ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #arg ", " #matcher, resultDisposition ); \ + __catchResult.setExceptionGuard(); \ + __catchResult.captureMatch( arg, matcher, #matcher ); \ + __catchResult.unsetExceptionGuard(); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +#else +/////////////////////////////////////////////////////////////////////////////// +// In the event of a failure works out if the debugger needs to be invoked +// and/or an exception thrown and takes appropriate action. +// This needs to be done as a macro so the debugger will stop in the user +// source code rather than in Catch library code +#define INTERNAL_CATCH_REACT( resultBuilder ) \ + if( resultBuilder.shouldDebugBreak() ) CATCH_BREAK_INTO_DEBUGGER(); \ + resultBuilder.react(); +#endif + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TEST( macroName, resultDisposition, expr ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + try { \ + CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + ( __catchResult <= expr ).endExpression(); \ + CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + } \ + catch( ... ) { \ + __catchResult.useActiveException( resultDisposition ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::isTrue( false && static_cast( !!(expr) ) ) ) // expr here is never evaluated at runtime but it forces the compiler to give it a look + // The double negation silences MSVC's C4800 warning, the static_cast forces short-circuit evaluation if the type has overloaded &&. + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_IF( macroName, resultDisposition, expr ) \ + INTERNAL_CATCH_TEST( macroName, resultDisposition, expr ); \ + if( Catch::getResultCapture().getLastResult()->succeeded() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_ELSE( macroName, resultDisposition, expr ) \ + INTERNAL_CATCH_TEST( macroName, resultDisposition, expr ); \ + if( !Catch::getResultCapture().getLastResult()->succeeded() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_NO_THROW( macroName, resultDisposition, expr ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + try { \ + static_cast(expr); \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + } \ + catch( ... ) { \ + __catchResult.useActiveException( resultDisposition ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS( macroName, resultDisposition, matcher, expr ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition, #matcher ); \ + if( __catchResult.allowThrows() ) \ + try { \ + static_cast(expr); \ + __catchResult.captureResult( Catch::ResultWas::DidntThrowException ); \ + } \ + catch( ... ) { \ + __catchResult.captureExpectedException( matcher ); \ + } \ + else \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS_AS( macroName, exceptionType, resultDisposition, expr ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr ", " #exceptionType, resultDisposition ); \ + if( __catchResult.allowThrows() ) \ + try { \ + static_cast(expr); \ + __catchResult.captureResult( Catch::ResultWas::DidntThrowException ); \ + } \ + catch( exceptionType ) { \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + } \ + catch( ... ) { \ + __catchResult.useActiveException( resultDisposition ); \ + } \ + else \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +/////////////////////////////////////////////////////////////////////////////// +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, ... ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, "", resultDisposition ); \ + __catchResult << __VA_ARGS__ + ::Catch::StreamEndStop(); \ + __catchResult.captureResult( messageType ); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) +#else + #define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, log ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, "", resultDisposition ); \ + __catchResult << log + ::Catch::StreamEndStop(); \ + __catchResult.captureResult( messageType ); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) +#endif + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_INFO( macroName, log ) \ + Catch::ScopedMessage INTERNAL_CATCH_UNIQUE_NAME( scopedMessage ) = Catch::MessageBuilder( macroName, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log; + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CHECK_THAT( macroName, matcher, resultDisposition, arg ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #arg ", " #matcher, resultDisposition ); \ + try { \ + __catchResult.captureMatch( arg, matcher, #matcher ); \ + } catch( ... ) { \ + __catchResult.useActiveException( resultDisposition | Catch::ResultDisposition::ContinueOnFailure ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +// #included from: internal/catch_section.h +#define TWOBLUECUBES_CATCH_SECTION_H_INCLUDED + +// #included from: catch_section_info.h +#define TWOBLUECUBES_CATCH_SECTION_INFO_H_INCLUDED + +// #included from: catch_totals.hpp +#define TWOBLUECUBES_CATCH_TOTALS_HPP_INCLUDED + +#include + +namespace Catch { + + struct Counts { + Counts() : passed( 0 ), failed( 0 ), failedButOk( 0 ) {} + + Counts operator - ( Counts const& other ) const { + Counts diff; + diff.passed = passed - other.passed; + diff.failed = failed - other.failed; + diff.failedButOk = failedButOk - other.failedButOk; + return diff; + } + Counts& operator += ( Counts const& other ) { + passed += other.passed; + failed += other.failed; + failedButOk += other.failedButOk; + return *this; + } + + std::size_t total() const { + return passed + failed + failedButOk; + } + bool allPassed() const { + return failed == 0 && failedButOk == 0; + } + bool allOk() const { + return failed == 0; + } + + std::size_t passed; + std::size_t failed; + std::size_t failedButOk; + }; + + struct Totals { + + Totals operator - ( Totals const& other ) const { + Totals diff; + diff.assertions = assertions - other.assertions; + diff.testCases = testCases - other.testCases; + return diff; + } + + Totals delta( Totals const& prevTotals ) const { + Totals diff = *this - prevTotals; + if( diff.assertions.failed > 0 ) + ++diff.testCases.failed; + else if( diff.assertions.failedButOk > 0 ) + ++diff.testCases.failedButOk; + else + ++diff.testCases.passed; + return diff; + } + + Totals& operator += ( Totals const& other ) { + assertions += other.assertions; + testCases += other.testCases; + return *this; + } + + Counts assertions; + Counts testCases; + }; +} + +#include + +namespace Catch { + + struct SectionInfo { + SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name, + std::string const& _description = std::string() ); + + std::string name; + std::string description; + SourceLineInfo lineInfo; + }; + + struct SectionEndInfo { + SectionEndInfo( SectionInfo const& _sectionInfo, Counts const& _prevAssertions, double _durationInSeconds ) + : sectionInfo( _sectionInfo ), prevAssertions( _prevAssertions ), durationInSeconds( _durationInSeconds ) + {} + + SectionInfo sectionInfo; + Counts prevAssertions; + double durationInSeconds; + }; + +} // end namespace Catch + +// #included from: catch_timer.h +#define TWOBLUECUBES_CATCH_TIMER_H_INCLUDED + +#ifdef _MSC_VER + +namespace Catch { + typedef unsigned long long UInt64; +} +#else +#include +namespace Catch { + typedef uint64_t UInt64; +} +#endif + +namespace Catch { + class Timer { + public: + Timer() : m_ticks( 0 ) {} + void start(); + unsigned int getElapsedMicroseconds() const; + unsigned int getElapsedMilliseconds() const; + double getElapsedSeconds() const; + + private: + UInt64 m_ticks; + }; + +} // namespace Catch + +#include + +namespace Catch { + + class Section : NonCopyable { + public: + Section( SectionInfo const& info ); + ~Section(); + + // This indicates whether the section should be executed or not + operator bool() const; + + private: + SectionInfo m_info; + + std::string m_name; + Counts m_assertions; + bool m_sectionIncluded; + Timer m_timer; + }; + +} // end namespace Catch + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define INTERNAL_CATCH_SECTION( ... ) \ + if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, __VA_ARGS__ ) ) +#else + #define INTERNAL_CATCH_SECTION( name, desc ) \ + if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, name, desc ) ) +#endif + +// #included from: internal/catch_generators.hpp +#define TWOBLUECUBES_CATCH_GENERATORS_HPP_INCLUDED + +#include +#include +#include + +namespace Catch { + +template +struct IGenerator { + virtual ~IGenerator() {} + virtual T getValue( std::size_t index ) const = 0; + virtual std::size_t size () const = 0; +}; + +template +class BetweenGenerator : public IGenerator { +public: + BetweenGenerator( T from, T to ) : m_from( from ), m_to( to ){} + + virtual T getValue( std::size_t index ) const { + return m_from+static_cast( index ); + } + + virtual std::size_t size() const { + return static_cast( 1+m_to-m_from ); + } + +private: + + T m_from; + T m_to; +}; + +template +class ValuesGenerator : public IGenerator { +public: + ValuesGenerator(){} + + void add( T value ) { + m_values.push_back( value ); + } + + virtual T getValue( std::size_t index ) const { + return m_values[index]; + } + + virtual std::size_t size() const { + return m_values.size(); + } + +private: + std::vector m_values; +}; + +template +class CompositeGenerator { +public: + CompositeGenerator() : m_totalSize( 0 ) {} + + // *** Move semantics, similar to auto_ptr *** + CompositeGenerator( CompositeGenerator& other ) + : m_fileInfo( other.m_fileInfo ), + m_totalSize( 0 ) + { + move( other ); + } + + CompositeGenerator& setFileInfo( const char* fileInfo ) { + m_fileInfo = fileInfo; + return *this; + } + + ~CompositeGenerator() { + deleteAll( m_composed ); + } + + operator T () const { + size_t overallIndex = getCurrentContext().getGeneratorIndex( m_fileInfo, m_totalSize ); + + typename std::vector*>::const_iterator it = m_composed.begin(); + typename std::vector*>::const_iterator itEnd = m_composed.end(); + for( size_t index = 0; it != itEnd; ++it ) + { + const IGenerator* generator = *it; + if( overallIndex >= index && overallIndex < index + generator->size() ) + { + return generator->getValue( overallIndex-index ); + } + index += generator->size(); + } + CATCH_INTERNAL_ERROR( "Indexed past end of generated range" ); + return T(); // Suppress spurious "not all control paths return a value" warning in Visual Studio - if you know how to fix this please do so + } + + void add( const IGenerator* generator ) { + m_totalSize += generator->size(); + m_composed.push_back( generator ); + } + + CompositeGenerator& then( CompositeGenerator& other ) { + move( other ); + return *this; + } + + CompositeGenerator& then( T value ) { + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( value ); + add( valuesGen ); + return *this; + } + +private: + + void move( CompositeGenerator& other ) { + m_composed.insert( m_composed.end(), other.m_composed.begin(), other.m_composed.end() ); + m_totalSize += other.m_totalSize; + other.m_composed.clear(); + } + + std::vector*> m_composed; + std::string m_fileInfo; + size_t m_totalSize; +}; + +namespace Generators +{ + template + CompositeGenerator between( T from, T to ) { + CompositeGenerator generators; + generators.add( new BetweenGenerator( from, to ) ); + return generators; + } + + template + CompositeGenerator values( T val1, T val2 ) { + CompositeGenerator generators; + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( val1 ); + valuesGen->add( val2 ); + generators.add( valuesGen ); + return generators; + } + + template + CompositeGenerator values( T val1, T val2, T val3 ){ + CompositeGenerator generators; + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( val1 ); + valuesGen->add( val2 ); + valuesGen->add( val3 ); + generators.add( valuesGen ); + return generators; + } + + template + CompositeGenerator values( T val1, T val2, T val3, T val4 ) { + CompositeGenerator generators; + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( val1 ); + valuesGen->add( val2 ); + valuesGen->add( val3 ); + valuesGen->add( val4 ); + generators.add( valuesGen ); + return generators; + } + +} // end namespace Generators + +using namespace Generators; + +} // end namespace Catch + +#define INTERNAL_CATCH_LINESTR2( line ) #line +#define INTERNAL_CATCH_LINESTR( line ) INTERNAL_CATCH_LINESTR2( line ) + +#define INTERNAL_CATCH_GENERATE( expr ) expr.setFileInfo( __FILE__ "(" INTERNAL_CATCH_LINESTR( __LINE__ ) ")" ) + +// #included from: internal/catch_interfaces_exception.h +#define TWOBLUECUBES_CATCH_INTERFACES_EXCEPTION_H_INCLUDED + +#include +#include + +// #included from: catch_interfaces_registry_hub.h +#define TWOBLUECUBES_CATCH_INTERFACES_REGISTRY_HUB_H_INCLUDED + +#include + +namespace Catch { + + class TestCase; + struct ITestCaseRegistry; + struct IExceptionTranslatorRegistry; + struct IExceptionTranslator; + struct IReporterRegistry; + struct IReporterFactory; + struct ITagAliasRegistry; + + struct IRegistryHub { + virtual ~IRegistryHub(); + + virtual IReporterRegistry const& getReporterRegistry() const = 0; + virtual ITestCaseRegistry const& getTestCaseRegistry() const = 0; + virtual ITagAliasRegistry const& getTagAliasRegistry() const = 0; + + virtual IExceptionTranslatorRegistry& getExceptionTranslatorRegistry() = 0; + }; + + struct IMutableRegistryHub { + virtual ~IMutableRegistryHub(); + virtual void registerReporter( std::string const& name, Ptr const& factory ) = 0; + virtual void registerListener( Ptr const& factory ) = 0; + virtual void registerTest( TestCase const& testInfo ) = 0; + virtual void registerTranslator( const IExceptionTranslator* translator ) = 0; + virtual void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) = 0; + }; + + IRegistryHub& getRegistryHub(); + IMutableRegistryHub& getMutableRegistryHub(); + void cleanUp(); + std::string translateActiveException(); + +} + +namespace Catch { + + typedef std::string(*exceptionTranslateFunction)(); + + struct IExceptionTranslator; + typedef std::vector ExceptionTranslators; + + struct IExceptionTranslator { + virtual ~IExceptionTranslator(); + virtual std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const = 0; + }; + + struct IExceptionTranslatorRegistry { + virtual ~IExceptionTranslatorRegistry(); + + virtual std::string translateActiveException() const = 0; + }; + + class ExceptionTranslatorRegistrar { + template + class ExceptionTranslator : public IExceptionTranslator { + public: + + ExceptionTranslator( std::string(*translateFunction)( T& ) ) + : m_translateFunction( translateFunction ) + {} + + virtual std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const CATCH_OVERRIDE { + try { + if( it == itEnd ) + throw; + else + return (*it)->translate( it+1, itEnd ); + } + catch( T& ex ) { + return m_translateFunction( ex ); + } + } + + protected: + std::string(*m_translateFunction)( T& ); + }; + + public: + template + ExceptionTranslatorRegistrar( std::string(*translateFunction)( T& ) ) { + getMutableRegistryHub().registerTranslator + ( new ExceptionTranslator( translateFunction ) ); + } + }; +} + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION2( translatorName, signature ) \ + static std::string translatorName( signature ); \ + namespace{ Catch::ExceptionTranslatorRegistrar INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionRegistrar )( &translatorName ); }\ + static std::string translatorName( signature ) + +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION2( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature ) + +// #included from: internal/catch_approx.hpp +#define TWOBLUECUBES_CATCH_APPROX_HPP_INCLUDED + +#include +#include + +#if defined(CATCH_CONFIG_CPP11_TYPE_TRAITS) +#include +#endif + +namespace Catch { +namespace Detail { + + class Approx { + public: + explicit Approx ( double value ) + : m_epsilon( std::numeric_limits::epsilon()*100 ), + m_margin( 0.0 ), + m_scale( 1.0 ), + m_value( value ) + {} + + Approx( Approx const& other ) + : m_epsilon( other.m_epsilon ), + m_margin( other.m_margin ), + m_scale( other.m_scale ), + m_value( other.m_value ) + {} + + static Approx custom() { + return Approx( 0 ); + } + +#if defined(CATCH_CONFIG_CPP11_TYPE_TRAITS) + + template ::value>::type> + Approx operator()( T value ) { + Approx approx( static_cast(value) ); + approx.epsilon( m_epsilon ); + approx.margin( m_margin ); + approx.scale( m_scale ); + return approx; + } + + template ::value>::type> + explicit Approx( T value ): Approx(static_cast(value)) + {} + + template ::value>::type> + friend bool operator == ( const T& lhs, Approx const& rhs ) { + // Thanks to Richard Harris for his help refining this formula + auto lhs_v = double(lhs); + bool relativeOK = std::fabs(lhs_v - rhs.m_value) < rhs.m_epsilon * (rhs.m_scale + (std::max)(std::fabs(lhs_v), std::fabs(rhs.m_value))); + if (relativeOK) { + return true; + } + return std::fabs(lhs_v - rhs.m_value) < rhs.m_margin; + } + + template ::value>::type> + friend bool operator == ( Approx const& lhs, const T& rhs ) { + return operator==( rhs, lhs ); + } + + template ::value>::type> + friend bool operator != ( T lhs, Approx const& rhs ) { + return !operator==( lhs, rhs ); + } + + template ::value>::type> + friend bool operator != ( Approx const& lhs, T rhs ) { + return !operator==( rhs, lhs ); + } + + template ::value>::type> + friend bool operator <= ( T lhs, Approx const& rhs ) { + return double(lhs) < rhs.m_value || lhs == rhs; + } + + template ::value>::type> + friend bool operator <= ( Approx const& lhs, T rhs ) { + return lhs.m_value < double(rhs) || lhs == rhs; + } + + template ::value>::type> + friend bool operator >= ( T lhs, Approx const& rhs ) { + return double(lhs) > rhs.m_value || lhs == rhs; + } + + template ::value>::type> + friend bool operator >= ( Approx const& lhs, T rhs ) { + return lhs.m_value > double(rhs) || lhs == rhs; + } + + template ::value>::type> + Approx& epsilon( T newEpsilon ) { + m_epsilon = double(newEpsilon); + return *this; + } + + template ::value>::type> + Approx& margin( T newMargin ) { + m_margin = double(newMargin); + return *this; + } + + template ::value>::type> + Approx& scale( T newScale ) { + m_scale = double(newScale); + return *this; + } + +#else + + Approx operator()( double value ) { + Approx approx( value ); + approx.epsilon( m_epsilon ); + approx.margin( m_margin ); + approx.scale( m_scale ); + return approx; + } + + friend bool operator == ( double lhs, Approx const& rhs ) { + // Thanks to Richard Harris for his help refining this formula + bool relativeOK = std::fabs( lhs - rhs.m_value ) < rhs.m_epsilon * (rhs.m_scale + (std::max)( std::fabs(lhs), std::fabs(rhs.m_value) ) ); + if (relativeOK) { + return true; + } + return std::fabs(lhs - rhs.m_value) < rhs.m_margin; + } + + friend bool operator == ( Approx const& lhs, double rhs ) { + return operator==( rhs, lhs ); + } + + friend bool operator != ( double lhs, Approx const& rhs ) { + return !operator==( lhs, rhs ); + } + + friend bool operator != ( Approx const& lhs, double rhs ) { + return !operator==( rhs, lhs ); + } + + friend bool operator <= ( double lhs, Approx const& rhs ) { + return lhs < rhs.m_value || lhs == rhs; + } + + friend bool operator <= ( Approx const& lhs, double rhs ) { + return lhs.m_value < rhs || lhs == rhs; + } + + friend bool operator >= ( double lhs, Approx const& rhs ) { + return lhs > rhs.m_value || lhs == rhs; + } + + friend bool operator >= ( Approx const& lhs, double rhs ) { + return lhs.m_value > rhs || lhs == rhs; + } + + Approx& epsilon( double newEpsilon ) { + m_epsilon = newEpsilon; + return *this; + } + + Approx& margin( double newMargin ) { + m_margin = newMargin; + return *this; + } + + Approx& scale( double newScale ) { + m_scale = newScale; + return *this; + } +#endif + + std::string toString() const { + std::ostringstream oss; + oss << "Approx( " << Catch::toString( m_value ) << " )"; + return oss.str(); + } + + private: + double m_epsilon; + double m_margin; + double m_scale; + double m_value; + }; +} + +template<> +inline std::string toString( Detail::Approx const& value ) { + return value.toString(); +} + +} // end namespace Catch + +// #included from: internal/catch_matchers_string.h +#define TWOBLUECUBES_CATCH_MATCHERS_STRING_H_INCLUDED + +namespace Catch { +namespace Matchers { + + namespace StdString { + + struct CasedString + { + CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity ); + std::string adjustString( std::string const& str ) const; + std::string caseSensitivitySuffix() const; + + CaseSensitive::Choice m_caseSensitivity; + std::string m_str; + }; + + struct StringMatcherBase : MatcherBase { + StringMatcherBase( std::string const& operation, CasedString const& comparator ); + virtual std::string describe() const CATCH_OVERRIDE; + + CasedString m_comparator; + std::string m_operation; + }; + + struct EqualsMatcher : StringMatcherBase { + EqualsMatcher( CasedString const& comparator ); + virtual bool match( std::string const& source ) const CATCH_OVERRIDE; + }; + struct ContainsMatcher : StringMatcherBase { + ContainsMatcher( CasedString const& comparator ); + virtual bool match( std::string const& source ) const CATCH_OVERRIDE; + }; + struct StartsWithMatcher : StringMatcherBase { + StartsWithMatcher( CasedString const& comparator ); + virtual bool match( std::string const& source ) const CATCH_OVERRIDE; + }; + struct EndsWithMatcher : StringMatcherBase { + EndsWithMatcher( CasedString const& comparator ); + virtual bool match( std::string const& source ) const CATCH_OVERRIDE; + }; + + } // namespace StdString + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + + StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + +} // namespace Matchers +} // namespace Catch + +// #included from: internal/catch_matchers_vector.h +#define TWOBLUECUBES_CATCH_MATCHERS_VECTOR_H_INCLUDED + +namespace Catch { +namespace Matchers { + + namespace Vector { + + template + struct ContainsElementMatcher : MatcherBase, T> { + + ContainsElementMatcher(T const &comparator) : m_comparator( comparator) {} + + bool match(std::vector const &v) const CATCH_OVERRIDE { + return std::find(v.begin(), v.end(), m_comparator) != v.end(); + } + + virtual std::string describe() const CATCH_OVERRIDE { + return "Contains: " + Catch::toString( m_comparator ); + } + + T const& m_comparator; + }; + + template + struct ContainsMatcher : MatcherBase, std::vector > { + + ContainsMatcher(std::vector const &comparator) : m_comparator( comparator ) {} + + bool match(std::vector const &v) const CATCH_OVERRIDE { + // !TBD: see note in EqualsMatcher + if (m_comparator.size() > v.size()) + return false; + for (size_t i = 0; i < m_comparator.size(); ++i) + if (std::find(v.begin(), v.end(), m_comparator[i]) == v.end()) + return false; + return true; + } + virtual std::string describe() const CATCH_OVERRIDE { + return "Contains: " + Catch::toString( m_comparator ); + } + + std::vector const& m_comparator; + }; + + template + struct EqualsMatcher : MatcherBase, std::vector > { + + EqualsMatcher(std::vector const &comparator) : m_comparator( comparator ) {} + + bool match(std::vector const &v) const CATCH_OVERRIDE { + // !TBD: This currently works if all elements can be compared using != + // - a more general approach would be via a compare template that defaults + // to using !=. but could be specialised for, e.g. std::vector etc + // - then just call that directly + if (m_comparator.size() != v.size()) + return false; + for (size_t i = 0; i < v.size(); ++i) + if (m_comparator[i] != v[i]) + return false; + return true; + } + virtual std::string describe() const CATCH_OVERRIDE { + return "Equals: " + Catch::toString( m_comparator ); + } + std::vector const& m_comparator; + }; + + } // namespace Vector + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + + template + Vector::ContainsMatcher Contains( std::vector const& comparator ) { + return Vector::ContainsMatcher( comparator ); + } + + template + Vector::ContainsElementMatcher VectorContains( T const& comparator ) { + return Vector::ContainsElementMatcher( comparator ); + } + + template + Vector::EqualsMatcher Equals( std::vector const& comparator ) { + return Vector::EqualsMatcher( comparator ); + } + +} // namespace Matchers +} // namespace Catch + +// #included from: internal/catch_interfaces_tag_alias_registry.h +#define TWOBLUECUBES_CATCH_INTERFACES_TAG_ALIAS_REGISTRY_H_INCLUDED + +// #included from: catch_tag_alias.h +#define TWOBLUECUBES_CATCH_TAG_ALIAS_H_INCLUDED + +#include + +namespace Catch { + + struct TagAlias { + TagAlias( std::string const& _tag, SourceLineInfo _lineInfo ) : tag( _tag ), lineInfo( _lineInfo ) {} + + std::string tag; + SourceLineInfo lineInfo; + }; + + struct RegistrarForTagAliases { + RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo ); + }; + +} // end namespace Catch + +#define CATCH_REGISTER_TAG_ALIAS( alias, spec ) namespace{ Catch::RegistrarForTagAliases INTERNAL_CATCH_UNIQUE_NAME( AutoRegisterTagAlias )( alias, spec, CATCH_INTERNAL_LINEINFO ); } +// #included from: catch_option.hpp +#define TWOBLUECUBES_CATCH_OPTION_HPP_INCLUDED + +namespace Catch { + + // An optional type + template + class Option { + public: + Option() : nullableValue( CATCH_NULL ) {} + Option( T const& _value ) + : nullableValue( new( storage ) T( _value ) ) + {} + Option( Option const& _other ) + : nullableValue( _other ? new( storage ) T( *_other ) : CATCH_NULL ) + {} + + ~Option() { + reset(); + } + + Option& operator= ( Option const& _other ) { + if( &_other != this ) { + reset(); + if( _other ) + nullableValue = new( storage ) T( *_other ); + } + return *this; + } + Option& operator = ( T const& _value ) { + reset(); + nullableValue = new( storage ) T( _value ); + return *this; + } + + void reset() { + if( nullableValue ) + nullableValue->~T(); + nullableValue = CATCH_NULL; + } + + T& operator*() { return *nullableValue; } + T const& operator*() const { return *nullableValue; } + T* operator->() { return nullableValue; } + const T* operator->() const { return nullableValue; } + + T valueOr( T const& defaultValue ) const { + return nullableValue ? *nullableValue : defaultValue; + } + + bool some() const { return nullableValue != CATCH_NULL; } + bool none() const { return nullableValue == CATCH_NULL; } + + bool operator !() const { return nullableValue == CATCH_NULL; } + operator SafeBool::type() const { + return SafeBool::makeSafe( some() ); + } + + private: + T *nullableValue; + union { + char storage[sizeof(T)]; + + // These are here to force alignment for the storage + long double dummy1; + void (*dummy2)(); + long double dummy3; +#ifdef CATCH_CONFIG_CPP11_LONG_LONG + long long dummy4; +#endif + }; + }; + +} // end namespace Catch + +namespace Catch { + + struct ITagAliasRegistry { + virtual ~ITagAliasRegistry(); + virtual Option find( std::string const& alias ) const = 0; + virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const = 0; + + static ITagAliasRegistry const& get(); + }; + +} // end namespace Catch + +// These files are included here so the single_include script doesn't put them +// in the conditionally compiled sections +// #included from: internal/catch_test_case_info.h +#define TWOBLUECUBES_CATCH_TEST_CASE_INFO_H_INCLUDED + +#include +#include + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + struct ITestCase; + + struct TestCaseInfo { + enum SpecialProperties{ + None = 0, + IsHidden = 1 << 1, + ShouldFail = 1 << 2, + MayFail = 1 << 3, + Throws = 1 << 4, + NonPortable = 1 << 5 + }; + + TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::set const& _tags, + SourceLineInfo const& _lineInfo ); + + TestCaseInfo( TestCaseInfo const& other ); + + friend void setTags( TestCaseInfo& testCaseInfo, std::set const& tags ); + + bool isHidden() const; + bool throws() const; + bool okToFail() const; + bool expectedToFail() const; + + std::string name; + std::string className; + std::string description; + std::set tags; + std::set lcaseTags; + std::string tagsAsString; + SourceLineInfo lineInfo; + SpecialProperties properties; + }; + + class TestCase : public TestCaseInfo { + public: + + TestCase( ITestCase* testCase, TestCaseInfo const& info ); + TestCase( TestCase const& other ); + + TestCase withName( std::string const& _newName ) const; + + void invoke() const; + + TestCaseInfo const& getTestCaseInfo() const; + + void swap( TestCase& other ); + bool operator == ( TestCase const& other ) const; + bool operator < ( TestCase const& other ) const; + TestCase& operator = ( TestCase const& other ); + + private: + Ptr test; + }; + + TestCase makeTestCase( ITestCase* testCase, + std::string const& className, + std::string const& name, + std::string const& description, + SourceLineInfo const& lineInfo ); +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + +#ifdef __OBJC__ +// #included from: internal/catch_objc.hpp +#define TWOBLUECUBES_CATCH_OBJC_HPP_INCLUDED + +#import + +#include + +// NB. Any general catch headers included here must be included +// in catch.hpp first to make sure they are included by the single +// header for non obj-usage + +/////////////////////////////////////////////////////////////////////////////// +// This protocol is really only here for (self) documenting purposes, since +// all its methods are optional. +@protocol OcFixture + +@optional + +-(void) setUp; +-(void) tearDown; + +@end + +namespace Catch { + + class OcMethod : public SharedImpl { + + public: + OcMethod( Class cls, SEL sel ) : m_cls( cls ), m_sel( sel ) {} + + virtual void invoke() const { + id obj = [[m_cls alloc] init]; + + performOptionalSelector( obj, @selector(setUp) ); + performOptionalSelector( obj, m_sel ); + performOptionalSelector( obj, @selector(tearDown) ); + + arcSafeRelease( obj ); + } + private: + virtual ~OcMethod() {} + + Class m_cls; + SEL m_sel; + }; + + namespace Detail{ + + inline std::string getAnnotation( Class cls, + std::string const& annotationName, + std::string const& testCaseName ) { + NSString* selStr = [[NSString alloc] initWithFormat:@"Catch_%s_%s", annotationName.c_str(), testCaseName.c_str()]; + SEL sel = NSSelectorFromString( selStr ); + arcSafeRelease( selStr ); + id value = performOptionalSelector( cls, sel ); + if( value ) + return [(NSString*)value UTF8String]; + return ""; + } + } + + inline size_t registerTestMethods() { + size_t noTestMethods = 0; + int noClasses = objc_getClassList( CATCH_NULL, 0 ); + + Class* classes = (CATCH_UNSAFE_UNRETAINED Class *)malloc( sizeof(Class) * noClasses); + objc_getClassList( classes, noClasses ); + + for( int c = 0; c < noClasses; c++ ) { + Class cls = classes[c]; + { + u_int count; + Method* methods = class_copyMethodList( cls, &count ); + for( u_int m = 0; m < count ; m++ ) { + SEL selector = method_getName(methods[m]); + std::string methodName = sel_getName(selector); + if( startsWith( methodName, "Catch_TestCase_" ) ) { + std::string testCaseName = methodName.substr( 15 ); + std::string name = Detail::getAnnotation( cls, "Name", testCaseName ); + std::string desc = Detail::getAnnotation( cls, "Description", testCaseName ); + const char* className = class_getName( cls ); + + getMutableRegistryHub().registerTest( makeTestCase( new OcMethod( cls, selector ), className, name.c_str(), desc.c_str(), SourceLineInfo() ) ); + noTestMethods++; + } + } + free(methods); + } + } + return noTestMethods; + } + + namespace Matchers { + namespace Impl { + namespace NSStringMatchers { + + struct StringHolder : MatcherBase{ + StringHolder( NSString* substr ) : m_substr( [substr copy] ){} + StringHolder( StringHolder const& other ) : m_substr( [other.m_substr copy] ){} + StringHolder() { + arcSafeRelease( m_substr ); + } + + virtual bool match( NSString* arg ) const CATCH_OVERRIDE { + return false; + } + + NSString* m_substr; + }; + + struct Equals : StringHolder { + Equals( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( NSString* str ) const CATCH_OVERRIDE { + return (str != nil || m_substr == nil ) && + [str isEqualToString:m_substr]; + } + + virtual std::string describe() const CATCH_OVERRIDE { + return "equals string: " + Catch::toString( m_substr ); + } + }; + + struct Contains : StringHolder { + Contains( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( NSString* str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location != NSNotFound; + } + + virtual std::string describe() const CATCH_OVERRIDE { + return "contains string: " + Catch::toString( m_substr ); + } + }; + + struct StartsWith : StringHolder { + StartsWith( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( NSString* str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == 0; + } + + virtual std::string describe() const CATCH_OVERRIDE { + return "starts with: " + Catch::toString( m_substr ); + } + }; + struct EndsWith : StringHolder { + EndsWith( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( NSString* str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == [str length] - [m_substr length]; + } + + virtual std::string describe() const CATCH_OVERRIDE { + return "ends with: " + Catch::toString( m_substr ); + } + }; + + } // namespace NSStringMatchers + } // namespace Impl + + inline Impl::NSStringMatchers::Equals + Equals( NSString* substr ){ return Impl::NSStringMatchers::Equals( substr ); } + + inline Impl::NSStringMatchers::Contains + Contains( NSString* substr ){ return Impl::NSStringMatchers::Contains( substr ); } + + inline Impl::NSStringMatchers::StartsWith + StartsWith( NSString* substr ){ return Impl::NSStringMatchers::StartsWith( substr ); } + + inline Impl::NSStringMatchers::EndsWith + EndsWith( NSString* substr ){ return Impl::NSStringMatchers::EndsWith( substr ); } + + } // namespace Matchers + + using namespace Matchers; + +} // namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define OC_TEST_CASE( name, desc )\ ++(NSString*) INTERNAL_CATCH_UNIQUE_NAME( Catch_Name_test ) \ +{\ +return @ name; \ +}\ ++(NSString*) INTERNAL_CATCH_UNIQUE_NAME( Catch_Description_test ) \ +{ \ +return @ desc; \ +} \ +-(void) INTERNAL_CATCH_UNIQUE_NAME( Catch_TestCase_test ) + +#endif + +#ifdef CATCH_IMPL + +// !TBD: Move the leak detector code into a separate header +#ifdef CATCH_CONFIG_WINDOWS_CRTDBG +#include +class LeakDetector { +public: + LeakDetector() { + int flag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); + flag |= _CRTDBG_LEAK_CHECK_DF; + flag |= _CRTDBG_ALLOC_MEM_DF; + _CrtSetDbgFlag(flag); + _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR); + // Change this to leaking allocation's number to break there + _CrtSetBreakAlloc(-1); + } +}; +#else +class LeakDetector {}; +#endif + +LeakDetector leakDetector; + +// #included from: internal/catch_impl.hpp +#define TWOBLUECUBES_CATCH_IMPL_HPP_INCLUDED + +// Collect all the implementation files together here +// These are the equivalent of what would usually be cpp files + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wweak-vtables" +#endif + +// #included from: ../catch_session.hpp +#define TWOBLUECUBES_CATCH_RUNNER_HPP_INCLUDED + +// #included from: internal/catch_commandline.hpp +#define TWOBLUECUBES_CATCH_COMMANDLINE_HPP_INCLUDED + +// #included from: catch_config.hpp +#define TWOBLUECUBES_CATCH_CONFIG_HPP_INCLUDED + +// #included from: catch_test_spec_parser.hpp +#define TWOBLUECUBES_CATCH_TEST_SPEC_PARSER_HPP_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// #included from: catch_test_spec.hpp +#define TWOBLUECUBES_CATCH_TEST_SPEC_HPP_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// #included from: catch_wildcard_pattern.hpp +#define TWOBLUECUBES_CATCH_WILDCARD_PATTERN_HPP_INCLUDED + +#include + +namespace Catch +{ + class WildcardPattern { + enum WildcardPosition { + NoWildcard = 0, + WildcardAtStart = 1, + WildcardAtEnd = 2, + WildcardAtBothEnds = WildcardAtStart | WildcardAtEnd + }; + + public: + + WildcardPattern( std::string const& pattern, CaseSensitive::Choice caseSensitivity ) + : m_caseSensitivity( caseSensitivity ), + m_wildcard( NoWildcard ), + m_pattern( adjustCase( pattern ) ) + { + if( startsWith( m_pattern, '*' ) ) { + m_pattern = m_pattern.substr( 1 ); + m_wildcard = WildcardAtStart; + } + if( endsWith( m_pattern, '*' ) ) { + m_pattern = m_pattern.substr( 0, m_pattern.size()-1 ); + m_wildcard = static_cast( m_wildcard | WildcardAtEnd ); + } + } + virtual ~WildcardPattern(); + virtual bool matches( std::string const& str ) const { + switch( m_wildcard ) { + case NoWildcard: + return m_pattern == adjustCase( str ); + case WildcardAtStart: + return endsWith( adjustCase( str ), m_pattern ); + case WildcardAtEnd: + return startsWith( adjustCase( str ), m_pattern ); + case WildcardAtBothEnds: + return contains( adjustCase( str ), m_pattern ); + } + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunreachable-code" +#endif + throw std::logic_error( "Unknown enum" ); +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + } + private: + std::string adjustCase( std::string const& str ) const { + return m_caseSensitivity == CaseSensitive::No ? toLower( str ) : str; + } + CaseSensitive::Choice m_caseSensitivity; + WildcardPosition m_wildcard; + std::string m_pattern; + }; +} + +#include +#include + +namespace Catch { + + class TestSpec { + struct Pattern : SharedImpl<> { + virtual ~Pattern(); + virtual bool matches( TestCaseInfo const& testCase ) const = 0; + }; + class NamePattern : public Pattern { + public: + NamePattern( std::string const& name ) + : m_wildcardPattern( toLower( name ), CaseSensitive::No ) + {} + virtual ~NamePattern(); + virtual bool matches( TestCaseInfo const& testCase ) const { + return m_wildcardPattern.matches( toLower( testCase.name ) ); + } + private: + WildcardPattern m_wildcardPattern; + }; + + class TagPattern : public Pattern { + public: + TagPattern( std::string const& tag ) : m_tag( toLower( tag ) ) {} + virtual ~TagPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const { + return testCase.lcaseTags.find( m_tag ) != testCase.lcaseTags.end(); + } + private: + std::string m_tag; + }; + + class ExcludedPattern : public Pattern { + public: + ExcludedPattern( Ptr const& underlyingPattern ) : m_underlyingPattern( underlyingPattern ) {} + virtual ~ExcludedPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const { return !m_underlyingPattern->matches( testCase ); } + private: + Ptr m_underlyingPattern; + }; + + struct Filter { + std::vector > m_patterns; + + bool matches( TestCaseInfo const& testCase ) const { + // All patterns in a filter must match for the filter to be a match + for( std::vector >::const_iterator it = m_patterns.begin(), itEnd = m_patterns.end(); it != itEnd; ++it ) { + if( !(*it)->matches( testCase ) ) + return false; + } + return true; + } + }; + + public: + bool hasFilters() const { + return !m_filters.empty(); + } + bool matches( TestCaseInfo const& testCase ) const { + // A TestSpec matches if any filter matches + for( std::vector::const_iterator it = m_filters.begin(), itEnd = m_filters.end(); it != itEnd; ++it ) + if( it->matches( testCase ) ) + return true; + return false; + } + + private: + std::vector m_filters; + + friend class TestSpecParser; + }; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +namespace Catch { + + class TestSpecParser { + enum Mode{ None, Name, QuotedName, Tag, EscapedName }; + Mode m_mode; + bool m_exclusion; + std::size_t m_start, m_pos; + std::string m_arg; + std::vector m_escapeChars; + TestSpec::Filter m_currentFilter; + TestSpec m_testSpec; + ITagAliasRegistry const* m_tagAliases; + + public: + TestSpecParser( ITagAliasRegistry const& tagAliases ) : m_tagAliases( &tagAliases ) {} + + TestSpecParser& parse( std::string const& arg ) { + m_mode = None; + m_exclusion = false; + m_start = std::string::npos; + m_arg = m_tagAliases->expandAliases( arg ); + m_escapeChars.clear(); + for( m_pos = 0; m_pos < m_arg.size(); ++m_pos ) + visitChar( m_arg[m_pos] ); + if( m_mode == Name ) + addPattern(); + return *this; + } + TestSpec testSpec() { + addFilter(); + return m_testSpec; + } + private: + void visitChar( char c ) { + if( m_mode == None ) { + switch( c ) { + case ' ': return; + case '~': m_exclusion = true; return; + case '[': return startNewMode( Tag, ++m_pos ); + case '"': return startNewMode( QuotedName, ++m_pos ); + case '\\': return escape(); + default: startNewMode( Name, m_pos ); break; + } + } + if( m_mode == Name ) { + if( c == ',' ) { + addPattern(); + addFilter(); + } + else if( c == '[' ) { + if( subString() == "exclude:" ) + m_exclusion = true; + else + addPattern(); + startNewMode( Tag, ++m_pos ); + } + else if( c == '\\' ) + escape(); + } + else if( m_mode == EscapedName ) + m_mode = Name; + else if( m_mode == QuotedName && c == '"' ) + addPattern(); + else if( m_mode == Tag && c == ']' ) + addPattern(); + } + void startNewMode( Mode mode, std::size_t start ) { + m_mode = mode; + m_start = start; + } + void escape() { + if( m_mode == None ) + m_start = m_pos; + m_mode = EscapedName; + m_escapeChars.push_back( m_pos ); + } + std::string subString() const { return m_arg.substr( m_start, m_pos - m_start ); } + template + void addPattern() { + std::string token = subString(); + for( size_t i = 0; i < m_escapeChars.size(); ++i ) + token = token.substr( 0, m_escapeChars[i]-m_start-i ) + token.substr( m_escapeChars[i]-m_start-i+1 ); + m_escapeChars.clear(); + if( startsWith( token, "exclude:" ) ) { + m_exclusion = true; + token = token.substr( 8 ); + } + if( !token.empty() ) { + Ptr pattern = new T( token ); + if( m_exclusion ) + pattern = new TestSpec::ExcludedPattern( pattern ); + m_currentFilter.m_patterns.push_back( pattern ); + } + m_exclusion = false; + m_mode = None; + } + void addFilter() { + if( !m_currentFilter.m_patterns.empty() ) { + m_testSpec.m_filters.push_back( m_currentFilter ); + m_currentFilter = TestSpec::Filter(); + } + } + }; + inline TestSpec parseTestSpec( std::string const& arg ) { + return TestSpecParser( ITagAliasRegistry::get() ).parse( arg ).testSpec(); + } + +} // namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// #included from: catch_interfaces_config.h +#define TWOBLUECUBES_CATCH_INTERFACES_CONFIG_H_INCLUDED + +#include +#include +#include + +namespace Catch { + + struct Verbosity { enum Level { + NoOutput = 0, + Quiet, + Normal + }; }; + + struct WarnAbout { enum What { + Nothing = 0x00, + NoAssertions = 0x01 + }; }; + + struct ShowDurations { enum OrNot { + DefaultForReporter, + Always, + Never + }; }; + struct RunTests { enum InWhatOrder { + InDeclarationOrder, + InLexicographicalOrder, + InRandomOrder + }; }; + struct UseColour { enum YesOrNo { + Auto, + Yes, + No + }; }; + + class TestSpec; + + struct IConfig : IShared { + + virtual ~IConfig(); + + virtual bool allowThrows() const = 0; + virtual std::ostream& stream() const = 0; + virtual std::string name() const = 0; + virtual bool includeSuccessfulResults() const = 0; + virtual bool shouldDebugBreak() const = 0; + virtual bool warnAboutMissingAssertions() const = 0; + virtual int abortAfter() const = 0; + virtual bool showInvisibles() const = 0; + virtual ShowDurations::OrNot showDurations() const = 0; + virtual TestSpec const& testSpec() const = 0; + virtual RunTests::InWhatOrder runOrder() const = 0; + virtual unsigned int rngSeed() const = 0; + virtual UseColour::YesOrNo useColour() const = 0; + virtual std::vector const& getSectionsToRun() const = 0; + + }; +} + +// #included from: catch_stream.h +#define TWOBLUECUBES_CATCH_STREAM_H_INCLUDED + +// #included from: catch_streambuf.h +#define TWOBLUECUBES_CATCH_STREAMBUF_H_INCLUDED + +#include + +namespace Catch { + + class StreamBufBase : public std::streambuf { + public: + virtual ~StreamBufBase() CATCH_NOEXCEPT; + }; +} + +#include +#include +#include +#include + +namespace Catch { + + std::ostream& cout(); + std::ostream& cerr(); + + struct IStream { + virtual ~IStream() CATCH_NOEXCEPT; + virtual std::ostream& stream() const = 0; + }; + + class FileStream : public IStream { + mutable std::ofstream m_ofs; + public: + FileStream( std::string const& filename ); + virtual ~FileStream() CATCH_NOEXCEPT; + public: // IStream + virtual std::ostream& stream() const CATCH_OVERRIDE; + }; + + class CoutStream : public IStream { + mutable std::ostream m_os; + public: + CoutStream(); + virtual ~CoutStream() CATCH_NOEXCEPT; + + public: // IStream + virtual std::ostream& stream() const CATCH_OVERRIDE; + }; + + class DebugOutStream : public IStream { + CATCH_AUTO_PTR( StreamBufBase ) m_streamBuf; + mutable std::ostream m_os; + public: + DebugOutStream(); + virtual ~DebugOutStream() CATCH_NOEXCEPT; + + public: // IStream + virtual std::ostream& stream() const CATCH_OVERRIDE; + }; +} + +#include +#include +#include +#include + +#ifndef CATCH_CONFIG_CONSOLE_WIDTH +#define CATCH_CONFIG_CONSOLE_WIDTH 80 +#endif + +namespace Catch { + + struct ConfigData { + + ConfigData() + : listTests( false ), + listTags( false ), + listReporters( false ), + listTestNamesOnly( false ), + listExtraInfo( false ), + showSuccessfulTests( false ), + shouldDebugBreak( false ), + noThrow( false ), + showHelp( false ), + showInvisibles( false ), + filenamesAsTags( false ), + abortAfter( -1 ), + rngSeed( 0 ), + verbosity( Verbosity::Normal ), + warnings( WarnAbout::Nothing ), + showDurations( ShowDurations::DefaultForReporter ), + runOrder( RunTests::InDeclarationOrder ), + useColour( UseColour::Auto ) + {} + + bool listTests; + bool listTags; + bool listReporters; + bool listTestNamesOnly; + bool listExtraInfo; + + bool showSuccessfulTests; + bool shouldDebugBreak; + bool noThrow; + bool showHelp; + bool showInvisibles; + bool filenamesAsTags; + + int abortAfter; + unsigned int rngSeed; + + Verbosity::Level verbosity; + WarnAbout::What warnings; + ShowDurations::OrNot showDurations; + RunTests::InWhatOrder runOrder; + UseColour::YesOrNo useColour; + + std::string outputFilename; + std::string name; + std::string processName; + + std::vector reporterNames; + std::vector testsOrTags; + std::vector sectionsToRun; + }; + + class Config : public SharedImpl { + private: + Config( Config const& other ); + Config& operator = ( Config const& other ); + virtual void dummy(); + public: + + Config() + {} + + Config( ConfigData const& data ) + : m_data( data ), + m_stream( openStream() ) + { + if( !data.testsOrTags.empty() ) { + TestSpecParser parser( ITagAliasRegistry::get() ); + for( std::size_t i = 0; i < data.testsOrTags.size(); ++i ) + parser.parse( data.testsOrTags[i] ); + m_testSpec = parser.testSpec(); + } + } + + virtual ~Config() {} + + std::string const& getFilename() const { + return m_data.outputFilename ; + } + + bool listTests() const { return m_data.listTests; } + bool listTestNamesOnly() const { return m_data.listTestNamesOnly; } + bool listTags() const { return m_data.listTags; } + bool listReporters() const { return m_data.listReporters; } + bool listExtraInfo() const { return m_data.listExtraInfo; } + + std::string getProcessName() const { return m_data.processName; } + + std::vector const& getReporterNames() const { return m_data.reporterNames; } + std::vector const& getSectionsToRun() const CATCH_OVERRIDE { return m_data.sectionsToRun; } + + virtual TestSpec const& testSpec() const CATCH_OVERRIDE { return m_testSpec; } + + bool showHelp() const { return m_data.showHelp; } + + // IConfig interface + virtual bool allowThrows() const CATCH_OVERRIDE { return !m_data.noThrow; } + virtual std::ostream& stream() const CATCH_OVERRIDE { return m_stream->stream(); } + virtual std::string name() const CATCH_OVERRIDE { return m_data.name.empty() ? m_data.processName : m_data.name; } + virtual bool includeSuccessfulResults() const CATCH_OVERRIDE { return m_data.showSuccessfulTests; } + virtual bool warnAboutMissingAssertions() const CATCH_OVERRIDE { return m_data.warnings & WarnAbout::NoAssertions; } + virtual ShowDurations::OrNot showDurations() const CATCH_OVERRIDE { return m_data.showDurations; } + virtual RunTests::InWhatOrder runOrder() const CATCH_OVERRIDE { return m_data.runOrder; } + virtual unsigned int rngSeed() const CATCH_OVERRIDE { return m_data.rngSeed; } + virtual UseColour::YesOrNo useColour() const CATCH_OVERRIDE { return m_data.useColour; } + virtual bool shouldDebugBreak() const CATCH_OVERRIDE { return m_data.shouldDebugBreak; } + virtual int abortAfter() const CATCH_OVERRIDE { return m_data.abortAfter; } + virtual bool showInvisibles() const CATCH_OVERRIDE { return m_data.showInvisibles; } + + private: + + IStream const* openStream() { + if( m_data.outputFilename.empty() ) + return new CoutStream(); + else if( m_data.outputFilename[0] == '%' ) { + if( m_data.outputFilename == "%debug" ) + return new DebugOutStream(); + else + throw std::domain_error( "Unrecognised stream: " + m_data.outputFilename ); + } + else + return new FileStream( m_data.outputFilename ); + } + ConfigData m_data; + + CATCH_AUTO_PTR( IStream const ) m_stream; + TestSpec m_testSpec; + }; + +} // end namespace Catch + +// #included from: catch_clara.h +#define TWOBLUECUBES_CATCH_CLARA_H_INCLUDED + +// Use Catch's value for console width (store Clara's off to the side, if present) +#ifdef CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH CLARA_CONFIG_CONSOLE_WIDTH +#undef CLARA_CONFIG_CONSOLE_WIDTH +#endif +#define CLARA_CONFIG_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH + +// Declare Clara inside the Catch namespace +#define STITCH_CLARA_OPEN_NAMESPACE namespace Catch { +// #included from: ../external/clara.h + +// Version 0.0.2.4 + +// Only use header guard if we are not using an outer namespace +#if !defined(TWOBLUECUBES_CLARA_H_INCLUDED) || defined(STITCH_CLARA_OPEN_NAMESPACE) + +#ifndef STITCH_CLARA_OPEN_NAMESPACE +#define TWOBLUECUBES_CLARA_H_INCLUDED +#define STITCH_CLARA_OPEN_NAMESPACE +#define STITCH_CLARA_CLOSE_NAMESPACE +#else +#define STITCH_CLARA_CLOSE_NAMESPACE } +#endif + +#define STITCH_TBC_TEXT_FORMAT_OPEN_NAMESPACE STITCH_CLARA_OPEN_NAMESPACE + +// ----------- #included from tbc_text_format.h ----------- + +// Only use header guard if we are not using an outer namespace +#if !defined(TBC_TEXT_FORMAT_H_INCLUDED) || defined(STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE) +#ifndef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE +#define TBC_TEXT_FORMAT_H_INCLUDED +#endif + +#include +#include +#include +#include +#include + +// Use optional outer namespace +#ifdef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE +namespace STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE { +#endif + +namespace Tbc { + +#ifdef TBC_TEXT_FORMAT_CONSOLE_WIDTH + const unsigned int consoleWidth = TBC_TEXT_FORMAT_CONSOLE_WIDTH; +#else + const unsigned int consoleWidth = 80; +#endif + + struct TextAttributes { + TextAttributes() + : initialIndent( std::string::npos ), + indent( 0 ), + width( consoleWidth-1 ), + tabChar( '\t' ) + {} + + TextAttributes& setInitialIndent( std::size_t _value ) { initialIndent = _value; return *this; } + TextAttributes& setIndent( std::size_t _value ) { indent = _value; return *this; } + TextAttributes& setWidth( std::size_t _value ) { width = _value; return *this; } + TextAttributes& setTabChar( char _value ) { tabChar = _value; return *this; } + + std::size_t initialIndent; // indent of first line, or npos + std::size_t indent; // indent of subsequent lines, or all if initialIndent is npos + std::size_t width; // maximum width of text, including indent. Longer text will wrap + char tabChar; // If this char is seen the indent is changed to current pos + }; + + class Text { + public: + Text( std::string const& _str, TextAttributes const& _attr = TextAttributes() ) + : attr( _attr ) + { + std::string wrappableChars = " [({.,/|\\-"; + std::size_t indent = _attr.initialIndent != std::string::npos + ? _attr.initialIndent + : _attr.indent; + std::string remainder = _str; + + while( !remainder.empty() ) { + if( lines.size() >= 1000 ) { + lines.push_back( "... message truncated due to excessive size" ); + return; + } + std::size_t tabPos = std::string::npos; + std::size_t width = (std::min)( remainder.size(), _attr.width - indent ); + std::size_t pos = remainder.find_first_of( '\n' ); + if( pos <= width ) { + width = pos; + } + pos = remainder.find_last_of( _attr.tabChar, width ); + if( pos != std::string::npos ) { + tabPos = pos; + if( remainder[width] == '\n' ) + width--; + remainder = remainder.substr( 0, tabPos ) + remainder.substr( tabPos+1 ); + } + + if( width == remainder.size() ) { + spliceLine( indent, remainder, width ); + } + else if( remainder[width] == '\n' ) { + spliceLine( indent, remainder, width ); + if( width <= 1 || remainder.size() != 1 ) + remainder = remainder.substr( 1 ); + indent = _attr.indent; + } + else { + pos = remainder.find_last_of( wrappableChars, width ); + if( pos != std::string::npos && pos > 0 ) { + spliceLine( indent, remainder, pos ); + if( remainder[0] == ' ' ) + remainder = remainder.substr( 1 ); + } + else { + spliceLine( indent, remainder, width-1 ); + lines.back() += "-"; + } + if( lines.size() == 1 ) + indent = _attr.indent; + if( tabPos != std::string::npos ) + indent += tabPos; + } + } + } + + void spliceLine( std::size_t _indent, std::string& _remainder, std::size_t _pos ) { + lines.push_back( std::string( _indent, ' ' ) + _remainder.substr( 0, _pos ) ); + _remainder = _remainder.substr( _pos ); + } + + typedef std::vector::const_iterator const_iterator; + + const_iterator begin() const { return lines.begin(); } + const_iterator end() const { return lines.end(); } + std::string const& last() const { return lines.back(); } + std::size_t size() const { return lines.size(); } + std::string const& operator[]( std::size_t _index ) const { return lines[_index]; } + std::string toString() const { + std::ostringstream oss; + oss << *this; + return oss.str(); + } + + inline friend std::ostream& operator << ( std::ostream& _stream, Text const& _text ) { + for( Text::const_iterator it = _text.begin(), itEnd = _text.end(); + it != itEnd; ++it ) { + if( it != _text.begin() ) + _stream << "\n"; + _stream << *it; + } + return _stream; + } + + private: + std::string str; + TextAttributes attr; + std::vector lines; + }; + +} // end namespace Tbc + +#ifdef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE +} // end outer namespace +#endif + +#endif // TBC_TEXT_FORMAT_H_INCLUDED + +// ----------- end of #include from tbc_text_format.h ----------- +// ........... back in clara.h + +#undef STITCH_TBC_TEXT_FORMAT_OPEN_NAMESPACE + +// ----------- #included from clara_compilers.h ----------- + +#ifndef TWOBLUECUBES_CLARA_COMPILERS_H_INCLUDED +#define TWOBLUECUBES_CLARA_COMPILERS_H_INCLUDED + +// Detect a number of compiler features - mostly C++11/14 conformance - by compiler +// The following features are defined: +// +// CLARA_CONFIG_CPP11_NULLPTR : is nullptr supported? +// CLARA_CONFIG_CPP11_NOEXCEPT : is noexcept supported? +// CLARA_CONFIG_CPP11_GENERATED_METHODS : The delete and default keywords for compiler generated methods +// CLARA_CONFIG_CPP11_OVERRIDE : is override supported? +// CLARA_CONFIG_CPP11_UNIQUE_PTR : is unique_ptr supported (otherwise use auto_ptr) + +// CLARA_CONFIG_CPP11_OR_GREATER : Is C++11 supported? + +// CLARA_CONFIG_VARIADIC_MACROS : are variadic macros supported? + +// In general each macro has a _NO_ form +// (e.g. CLARA_CONFIG_CPP11_NO_NULLPTR) which disables the feature. +// Many features, at point of detection, define an _INTERNAL_ macro, so they +// can be combined, en-mass, with the _NO_ forms later. + +// All the C++11 features can be disabled with CLARA_CONFIG_NO_CPP11 + +#ifdef __clang__ + +#if __has_feature(cxx_nullptr) +#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR +#endif + +#if __has_feature(cxx_noexcept) +#define CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT +#endif + +#endif // __clang__ + +//////////////////////////////////////////////////////////////////////////////// +// GCC +#ifdef __GNUC__ + +#if __GNUC__ == 4 && __GNUC_MINOR__ >= 6 && defined(__GXX_EXPERIMENTAL_CXX0X__) +#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR +#endif + +// - otherwise more recent versions define __cplusplus >= 201103L +// and will get picked up below + +#endif // __GNUC__ + +//////////////////////////////////////////////////////////////////////////////// +// Visual C++ +#ifdef _MSC_VER + +#if (_MSC_VER >= 1600) +#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR +#define CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR +#endif + +#if (_MSC_VER >= 1900 ) // (VC++ 13 (VS2015)) +#define CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT +#define CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS +#endif + +#endif // _MSC_VER + +//////////////////////////////////////////////////////////////////////////////// +// C++ language feature support + +// catch all support for C++11 +#if defined(__cplusplus) && __cplusplus >= 201103L + +#define CLARA_CPP11_OR_GREATER + +#if !defined(CLARA_INTERNAL_CONFIG_CPP11_NULLPTR) +#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR +#endif + +#ifndef CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT +#define CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT +#endif + +#ifndef CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS +#define CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS +#endif + +#if !defined(CLARA_INTERNAL_CONFIG_CPP11_OVERRIDE) +#define CLARA_INTERNAL_CONFIG_CPP11_OVERRIDE +#endif +#if !defined(CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR) +#define CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR +#endif + +#endif // __cplusplus >= 201103L + +// Now set the actual defines based on the above + anything the user has configured +#if defined(CLARA_INTERNAL_CONFIG_CPP11_NULLPTR) && !defined(CLARA_CONFIG_CPP11_NO_NULLPTR) && !defined(CLARA_CONFIG_CPP11_NULLPTR) && !defined(CLARA_CONFIG_NO_CPP11) +#define CLARA_CONFIG_CPP11_NULLPTR +#endif +#if defined(CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT) && !defined(CLARA_CONFIG_CPP11_NO_NOEXCEPT) && !defined(CLARA_CONFIG_CPP11_NOEXCEPT) && !defined(CLARA_CONFIG_NO_CPP11) +#define CLARA_CONFIG_CPP11_NOEXCEPT +#endif +#if defined(CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS) && !defined(CLARA_CONFIG_CPP11_NO_GENERATED_METHODS) && !defined(CLARA_CONFIG_CPP11_GENERATED_METHODS) && !defined(CLARA_CONFIG_NO_CPP11) +#define CLARA_CONFIG_CPP11_GENERATED_METHODS +#endif +#if defined(CLARA_INTERNAL_CONFIG_CPP11_OVERRIDE) && !defined(CLARA_CONFIG_NO_OVERRIDE) && !defined(CLARA_CONFIG_CPP11_OVERRIDE) && !defined(CLARA_CONFIG_NO_CPP11) +#define CLARA_CONFIG_CPP11_OVERRIDE +#endif +#if defined(CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR) && !defined(CLARA_CONFIG_NO_UNIQUE_PTR) && !defined(CLARA_CONFIG_CPP11_UNIQUE_PTR) && !defined(CLARA_CONFIG_NO_CPP11) +#define CLARA_CONFIG_CPP11_UNIQUE_PTR +#endif + +// noexcept support: +#if defined(CLARA_CONFIG_CPP11_NOEXCEPT) && !defined(CLARA_NOEXCEPT) +#define CLARA_NOEXCEPT noexcept +# define CLARA_NOEXCEPT_IS(x) noexcept(x) +#else +#define CLARA_NOEXCEPT throw() +# define CLARA_NOEXCEPT_IS(x) +#endif + +// nullptr support +#ifdef CLARA_CONFIG_CPP11_NULLPTR +#define CLARA_NULL nullptr +#else +#define CLARA_NULL NULL +#endif + +// override support +#ifdef CLARA_CONFIG_CPP11_OVERRIDE +#define CLARA_OVERRIDE override +#else +#define CLARA_OVERRIDE +#endif + +// unique_ptr support +#ifdef CLARA_CONFIG_CPP11_UNIQUE_PTR +# define CLARA_AUTO_PTR( T ) std::unique_ptr +#else +# define CLARA_AUTO_PTR( T ) std::auto_ptr +#endif + +#endif // TWOBLUECUBES_CLARA_COMPILERS_H_INCLUDED + +// ----------- end of #include from clara_compilers.h ----------- +// ........... back in clara.h + +#include +#include +#include + +#if defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) +#define CLARA_PLATFORM_WINDOWS +#endif + +// Use optional outer namespace +#ifdef STITCH_CLARA_OPEN_NAMESPACE +STITCH_CLARA_OPEN_NAMESPACE +#endif + +namespace Clara { + + struct UnpositionalTag {}; + + extern UnpositionalTag _; + +#ifdef CLARA_CONFIG_MAIN + UnpositionalTag _; +#endif + + namespace Detail { + +#ifdef CLARA_CONSOLE_WIDTH + const unsigned int consoleWidth = CLARA_CONFIG_CONSOLE_WIDTH; +#else + const unsigned int consoleWidth = 80; +#endif + + using namespace Tbc; + + inline bool startsWith( std::string const& str, std::string const& prefix ) { + return str.size() >= prefix.size() && str.substr( 0, prefix.size() ) == prefix; + } + + template struct RemoveConstRef{ typedef T type; }; + template struct RemoveConstRef{ typedef T type; }; + template struct RemoveConstRef{ typedef T type; }; + template struct RemoveConstRef{ typedef T type; }; + + template struct IsBool { static const bool value = false; }; + template<> struct IsBool { static const bool value = true; }; + + template + void convertInto( std::string const& _source, T& _dest ) { + std::stringstream ss; + ss << _source; + ss >> _dest; + if( ss.fail() ) + throw std::runtime_error( "Unable to convert " + _source + " to destination type" ); + } + inline void convertInto( std::string const& _source, std::string& _dest ) { + _dest = _source; + } + char toLowerCh(char c) { + return static_cast( std::tolower( c ) ); + } + inline void convertInto( std::string const& _source, bool& _dest ) { + std::string sourceLC = _source; + std::transform( sourceLC.begin(), sourceLC.end(), sourceLC.begin(), toLowerCh ); + if( sourceLC == "y" || sourceLC == "1" || sourceLC == "true" || sourceLC == "yes" || sourceLC == "on" ) + _dest = true; + else if( sourceLC == "n" || sourceLC == "0" || sourceLC == "false" || sourceLC == "no" || sourceLC == "off" ) + _dest = false; + else + throw std::runtime_error( "Expected a boolean value but did not recognise:\n '" + _source + "'" ); + } + + template + struct IArgFunction { + virtual ~IArgFunction() {} +#ifdef CLARA_CONFIG_CPP11_GENERATED_METHODS + IArgFunction() = default; + IArgFunction( IArgFunction const& ) = default; +#endif + virtual void set( ConfigT& config, std::string const& value ) const = 0; + virtual bool takesArg() const = 0; + virtual IArgFunction* clone() const = 0; + }; + + template + class BoundArgFunction { + public: + BoundArgFunction() : functionObj( CLARA_NULL ) {} + BoundArgFunction( IArgFunction* _functionObj ) : functionObj( _functionObj ) {} + BoundArgFunction( BoundArgFunction const& other ) : functionObj( other.functionObj ? other.functionObj->clone() : CLARA_NULL ) {} + BoundArgFunction& operator = ( BoundArgFunction const& other ) { + IArgFunction* newFunctionObj = other.functionObj ? other.functionObj->clone() : CLARA_NULL; + delete functionObj; + functionObj = newFunctionObj; + return *this; + } + ~BoundArgFunction() { delete functionObj; } + + void set( ConfigT& config, std::string const& value ) const { + functionObj->set( config, value ); + } + bool takesArg() const { return functionObj->takesArg(); } + + bool isSet() const { + return functionObj != CLARA_NULL; + } + private: + IArgFunction* functionObj; + }; + + template + struct NullBinder : IArgFunction{ + virtual void set( C&, std::string const& ) const {} + virtual bool takesArg() const { return true; } + virtual IArgFunction* clone() const { return new NullBinder( *this ); } + }; + + template + struct BoundDataMember : IArgFunction{ + BoundDataMember( M C::* _member ) : member( _member ) {} + virtual void set( C& p, std::string const& stringValue ) const { + convertInto( stringValue, p.*member ); + } + virtual bool takesArg() const { return !IsBool::value; } + virtual IArgFunction* clone() const { return new BoundDataMember( *this ); } + M C::* member; + }; + template + struct BoundUnaryMethod : IArgFunction{ + BoundUnaryMethod( void (C::*_member)( M ) ) : member( _member ) {} + virtual void set( C& p, std::string const& stringValue ) const { + typename RemoveConstRef::type value; + convertInto( stringValue, value ); + (p.*member)( value ); + } + virtual bool takesArg() const { return !IsBool::value; } + virtual IArgFunction* clone() const { return new BoundUnaryMethod( *this ); } + void (C::*member)( M ); + }; + template + struct BoundNullaryMethod : IArgFunction{ + BoundNullaryMethod( void (C::*_member)() ) : member( _member ) {} + virtual void set( C& p, std::string const& stringValue ) const { + bool value; + convertInto( stringValue, value ); + if( value ) + (p.*member)(); + } + virtual bool takesArg() const { return false; } + virtual IArgFunction* clone() const { return new BoundNullaryMethod( *this ); } + void (C::*member)(); + }; + + template + struct BoundUnaryFunction : IArgFunction{ + BoundUnaryFunction( void (*_function)( C& ) ) : function( _function ) {} + virtual void set( C& obj, std::string const& stringValue ) const { + bool value; + convertInto( stringValue, value ); + if( value ) + function( obj ); + } + virtual bool takesArg() const { return false; } + virtual IArgFunction* clone() const { return new BoundUnaryFunction( *this ); } + void (*function)( C& ); + }; + + template + struct BoundBinaryFunction : IArgFunction{ + BoundBinaryFunction( void (*_function)( C&, T ) ) : function( _function ) {} + virtual void set( C& obj, std::string const& stringValue ) const { + typename RemoveConstRef::type value; + convertInto( stringValue, value ); + function( obj, value ); + } + virtual bool takesArg() const { return !IsBool::value; } + virtual IArgFunction* clone() const { return new BoundBinaryFunction( *this ); } + void (*function)( C&, T ); + }; + + } // namespace Detail + + inline std::vector argsToVector( int argc, char const* const* const argv ) { + std::vector args( static_cast( argc ) ); + for( std::size_t i = 0; i < static_cast( argc ); ++i ) + args[i] = argv[i]; + + return args; + } + + class Parser { + enum Mode { None, MaybeShortOpt, SlashOpt, ShortOpt, LongOpt, Positional }; + Mode mode; + std::size_t from; + bool inQuotes; + public: + + struct Token { + enum Type { Positional, ShortOpt, LongOpt }; + Token( Type _type, std::string const& _data ) : type( _type ), data( _data ) {} + Type type; + std::string data; + }; + + Parser() : mode( None ), from( 0 ), inQuotes( false ){} + + void parseIntoTokens( std::vector const& args, std::vector& tokens ) { + const std::string doubleDash = "--"; + for( std::size_t i = 1; i < args.size() && args[i] != doubleDash; ++i ) + parseIntoTokens( args[i], tokens); + } + + void parseIntoTokens( std::string const& arg, std::vector& tokens ) { + for( std::size_t i = 0; i < arg.size(); ++i ) { + char c = arg[i]; + if( c == '"' ) + inQuotes = !inQuotes; + mode = handleMode( i, c, arg, tokens ); + } + mode = handleMode( arg.size(), '\0', arg, tokens ); + } + Mode handleMode( std::size_t i, char c, std::string const& arg, std::vector& tokens ) { + switch( mode ) { + case None: return handleNone( i, c ); + case MaybeShortOpt: return handleMaybeShortOpt( i, c ); + case ShortOpt: + case LongOpt: + case SlashOpt: return handleOpt( i, c, arg, tokens ); + case Positional: return handlePositional( i, c, arg, tokens ); + default: throw std::logic_error( "Unknown mode" ); + } + } + + Mode handleNone( std::size_t i, char c ) { + if( inQuotes ) { + from = i; + return Positional; + } + switch( c ) { + case '-': return MaybeShortOpt; +#ifdef CLARA_PLATFORM_WINDOWS + case '/': from = i+1; return SlashOpt; +#endif + default: from = i; return Positional; + } + } + Mode handleMaybeShortOpt( std::size_t i, char c ) { + switch( c ) { + case '-': from = i+1; return LongOpt; + default: from = i; return ShortOpt; + } + } + + Mode handleOpt( std::size_t i, char c, std::string const& arg, std::vector& tokens ) { + if( std::string( ":=\0", 3 ).find( c ) == std::string::npos ) + return mode; + + std::string optName = arg.substr( from, i-from ); + if( mode == ShortOpt ) + for( std::size_t j = 0; j < optName.size(); ++j ) + tokens.push_back( Token( Token::ShortOpt, optName.substr( j, 1 ) ) ); + else if( mode == SlashOpt && optName.size() == 1 ) + tokens.push_back( Token( Token::ShortOpt, optName ) ); + else + tokens.push_back( Token( Token::LongOpt, optName ) ); + return None; + } + Mode handlePositional( std::size_t i, char c, std::string const& arg, std::vector& tokens ) { + if( inQuotes || std::string( "\0", 1 ).find( c ) == std::string::npos ) + return mode; + + std::string data = arg.substr( from, i-from ); + tokens.push_back( Token( Token::Positional, data ) ); + return None; + } + }; + + template + struct CommonArgProperties { + CommonArgProperties() {} + CommonArgProperties( Detail::BoundArgFunction const& _boundField ) : boundField( _boundField ) {} + + Detail::BoundArgFunction boundField; + std::string description; + std::string detail; + std::string placeholder; // Only value if boundField takes an arg + + bool takesArg() const { + return !placeholder.empty(); + } + void validate() const { + if( !boundField.isSet() ) + throw std::logic_error( "option not bound" ); + } + }; + struct OptionArgProperties { + std::vector shortNames; + std::string longName; + + bool hasShortName( std::string const& shortName ) const { + return std::find( shortNames.begin(), shortNames.end(), shortName ) != shortNames.end(); + } + bool hasLongName( std::string const& _longName ) const { + return _longName == longName; + } + }; + struct PositionalArgProperties { + PositionalArgProperties() : position( -1 ) {} + int position; // -1 means non-positional (floating) + + bool isFixedPositional() const { + return position != -1; + } + }; + + template + class CommandLine { + + struct Arg : CommonArgProperties, OptionArgProperties, PositionalArgProperties { + Arg() {} + Arg( Detail::BoundArgFunction const& _boundField ) : CommonArgProperties( _boundField ) {} + + using CommonArgProperties::placeholder; // !TBD + + std::string dbgName() const { + if( !longName.empty() ) + return "--" + longName; + if( !shortNames.empty() ) + return "-" + shortNames[0]; + return "positional args"; + } + std::string commands() const { + std::ostringstream oss; + bool first = true; + std::vector::const_iterator it = shortNames.begin(), itEnd = shortNames.end(); + for(; it != itEnd; ++it ) { + if( first ) + first = false; + else + oss << ", "; + oss << "-" << *it; + } + if( !longName.empty() ) { + if( !first ) + oss << ", "; + oss << "--" << longName; + } + if( !placeholder.empty() ) + oss << " <" << placeholder << ">"; + return oss.str(); + } + }; + + typedef CLARA_AUTO_PTR( Arg ) ArgAutoPtr; + + friend void addOptName( Arg& arg, std::string const& optName ) + { + if( optName.empty() ) + return; + if( Detail::startsWith( optName, "--" ) ) { + if( !arg.longName.empty() ) + throw std::logic_error( "Only one long opt may be specified. '" + + arg.longName + + "' already specified, now attempting to add '" + + optName + "'" ); + arg.longName = optName.substr( 2 ); + } + else if( Detail::startsWith( optName, "-" ) ) + arg.shortNames.push_back( optName.substr( 1 ) ); + else + throw std::logic_error( "option must begin with - or --. Option was: '" + optName + "'" ); + } + friend void setPositionalArg( Arg& arg, int position ) + { + arg.position = position; + } + + class ArgBuilder { + public: + ArgBuilder( Arg* arg ) : m_arg( arg ) {} + + // Bind a non-boolean data member (requires placeholder string) + template + void bind( M C::* field, std::string const& placeholder ) { + m_arg->boundField = new Detail::BoundDataMember( field ); + m_arg->placeholder = placeholder; + } + // Bind a boolean data member (no placeholder required) + template + void bind( bool C::* field ) { + m_arg->boundField = new Detail::BoundDataMember( field ); + } + + // Bind a method taking a single, non-boolean argument (requires a placeholder string) + template + void bind( void (C::* unaryMethod)( M ), std::string const& placeholder ) { + m_arg->boundField = new Detail::BoundUnaryMethod( unaryMethod ); + m_arg->placeholder = placeholder; + } + + // Bind a method taking a single, boolean argument (no placeholder string required) + template + void bind( void (C::* unaryMethod)( bool ) ) { + m_arg->boundField = new Detail::BoundUnaryMethod( unaryMethod ); + } + + // Bind a method that takes no arguments (will be called if opt is present) + template + void bind( void (C::* nullaryMethod)() ) { + m_arg->boundField = new Detail::BoundNullaryMethod( nullaryMethod ); + } + + // Bind a free function taking a single argument - the object to operate on (no placeholder string required) + template + void bind( void (* unaryFunction)( C& ) ) { + m_arg->boundField = new Detail::BoundUnaryFunction( unaryFunction ); + } + + // Bind a free function taking a single argument - the object to operate on (requires a placeholder string) + template + void bind( void (* binaryFunction)( C&, T ), std::string const& placeholder ) { + m_arg->boundField = new Detail::BoundBinaryFunction( binaryFunction ); + m_arg->placeholder = placeholder; + } + + ArgBuilder& describe( std::string const& description ) { + m_arg->description = description; + return *this; + } + ArgBuilder& detail( std::string const& detail ) { + m_arg->detail = detail; + return *this; + } + + protected: + Arg* m_arg; + }; + + class OptBuilder : public ArgBuilder { + public: + OptBuilder( Arg* arg ) : ArgBuilder( arg ) {} + OptBuilder( OptBuilder& other ) : ArgBuilder( other ) {} + + OptBuilder& operator[]( std::string const& optName ) { + addOptName( *ArgBuilder::m_arg, optName ); + return *this; + } + }; + + public: + + CommandLine() + : m_boundProcessName( new Detail::NullBinder() ), + m_highestSpecifiedArgPosition( 0 ), + m_throwOnUnrecognisedTokens( false ) + {} + CommandLine( CommandLine const& other ) + : m_boundProcessName( other.m_boundProcessName ), + m_options ( other.m_options ), + m_positionalArgs( other.m_positionalArgs ), + m_highestSpecifiedArgPosition( other.m_highestSpecifiedArgPosition ), + m_throwOnUnrecognisedTokens( other.m_throwOnUnrecognisedTokens ) + { + if( other.m_floatingArg.get() ) + m_floatingArg.reset( new Arg( *other.m_floatingArg ) ); + } + + CommandLine& setThrowOnUnrecognisedTokens( bool shouldThrow = true ) { + m_throwOnUnrecognisedTokens = shouldThrow; + return *this; + } + + OptBuilder operator[]( std::string const& optName ) { + m_options.push_back( Arg() ); + addOptName( m_options.back(), optName ); + OptBuilder builder( &m_options.back() ); + return builder; + } + + ArgBuilder operator[]( int position ) { + m_positionalArgs.insert( std::make_pair( position, Arg() ) ); + if( position > m_highestSpecifiedArgPosition ) + m_highestSpecifiedArgPosition = position; + setPositionalArg( m_positionalArgs[position], position ); + ArgBuilder builder( &m_positionalArgs[position] ); + return builder; + } + + // Invoke this with the _ instance + ArgBuilder operator[]( UnpositionalTag ) { + if( m_floatingArg.get() ) + throw std::logic_error( "Only one unpositional argument can be added" ); + m_floatingArg.reset( new Arg() ); + ArgBuilder builder( m_floatingArg.get() ); + return builder; + } + + template + void bindProcessName( M C::* field ) { + m_boundProcessName = new Detail::BoundDataMember( field ); + } + template + void bindProcessName( void (C::*_unaryMethod)( M ) ) { + m_boundProcessName = new Detail::BoundUnaryMethod( _unaryMethod ); + } + + void optUsage( std::ostream& os, std::size_t indent = 0, std::size_t width = Detail::consoleWidth ) const { + typename std::vector::const_iterator itBegin = m_options.begin(), itEnd = m_options.end(), it; + std::size_t maxWidth = 0; + for( it = itBegin; it != itEnd; ++it ) + maxWidth = (std::max)( maxWidth, it->commands().size() ); + + for( it = itBegin; it != itEnd; ++it ) { + Detail::Text usage( it->commands(), Detail::TextAttributes() + .setWidth( maxWidth+indent ) + .setIndent( indent ) ); + Detail::Text desc( it->description, Detail::TextAttributes() + .setWidth( width - maxWidth - 3 ) ); + + for( std::size_t i = 0; i < (std::max)( usage.size(), desc.size() ); ++i ) { + std::string usageCol = i < usage.size() ? usage[i] : ""; + os << usageCol; + + if( i < desc.size() && !desc[i].empty() ) + os << std::string( indent + 2 + maxWidth - usageCol.size(), ' ' ) + << desc[i]; + os << "\n"; + } + } + } + std::string optUsage() const { + std::ostringstream oss; + optUsage( oss ); + return oss.str(); + } + + void argSynopsis( std::ostream& os ) const { + for( int i = 1; i <= m_highestSpecifiedArgPosition; ++i ) { + if( i > 1 ) + os << " "; + typename std::map::const_iterator it = m_positionalArgs.find( i ); + if( it != m_positionalArgs.end() ) + os << "<" << it->second.placeholder << ">"; + else if( m_floatingArg.get() ) + os << "<" << m_floatingArg->placeholder << ">"; + else + throw std::logic_error( "non consecutive positional arguments with no floating args" ); + } + // !TBD No indication of mandatory args + if( m_floatingArg.get() ) { + if( m_highestSpecifiedArgPosition > 1 ) + os << " "; + os << "[<" << m_floatingArg->placeholder << "> ...]"; + } + } + std::string argSynopsis() const { + std::ostringstream oss; + argSynopsis( oss ); + return oss.str(); + } + + void usage( std::ostream& os, std::string const& procName ) const { + validate(); + os << "usage:\n " << procName << " "; + argSynopsis( os ); + if( !m_options.empty() ) { + os << " [options]\n\nwhere options are: \n"; + optUsage( os, 2 ); + } + os << "\n"; + } + std::string usage( std::string const& procName ) const { + std::ostringstream oss; + usage( oss, procName ); + return oss.str(); + } + + ConfigT parse( std::vector const& args ) const { + ConfigT config; + parseInto( args, config ); + return config; + } + + std::vector parseInto( std::vector const& args, ConfigT& config ) const { + std::string processName = args.empty() ? std::string() : args[0]; + std::size_t lastSlash = processName.find_last_of( "/\\" ); + if( lastSlash != std::string::npos ) + processName = processName.substr( lastSlash+1 ); + m_boundProcessName.set( config, processName ); + std::vector tokens; + Parser parser; + parser.parseIntoTokens( args, tokens ); + return populate( tokens, config ); + } + + std::vector populate( std::vector const& tokens, ConfigT& config ) const { + validate(); + std::vector unusedTokens = populateOptions( tokens, config ); + unusedTokens = populateFixedArgs( unusedTokens, config ); + unusedTokens = populateFloatingArgs( unusedTokens, config ); + return unusedTokens; + } + + std::vector populateOptions( std::vector const& tokens, ConfigT& config ) const { + std::vector unusedTokens; + std::vector errors; + for( std::size_t i = 0; i < tokens.size(); ++i ) { + Parser::Token const& token = tokens[i]; + typename std::vector::const_iterator it = m_options.begin(), itEnd = m_options.end(); + for(; it != itEnd; ++it ) { + Arg const& arg = *it; + + try { + if( ( token.type == Parser::Token::ShortOpt && arg.hasShortName( token.data ) ) || + ( token.type == Parser::Token::LongOpt && arg.hasLongName( token.data ) ) ) { + if( arg.takesArg() ) { + if( i == tokens.size()-1 || tokens[i+1].type != Parser::Token::Positional ) + errors.push_back( "Expected argument to option: " + token.data ); + else + arg.boundField.set( config, tokens[++i].data ); + } + else { + arg.boundField.set( config, "true" ); + } + break; + } + } + catch( std::exception& ex ) { + errors.push_back( std::string( ex.what() ) + "\n- while parsing: (" + arg.commands() + ")" ); + } + } + if( it == itEnd ) { + if( token.type == Parser::Token::Positional || !m_throwOnUnrecognisedTokens ) + unusedTokens.push_back( token ); + else if( errors.empty() && m_throwOnUnrecognisedTokens ) + errors.push_back( "unrecognised option: " + token.data ); + } + } + if( !errors.empty() ) { + std::ostringstream oss; + for( std::vector::const_iterator it = errors.begin(), itEnd = errors.end(); + it != itEnd; + ++it ) { + if( it != errors.begin() ) + oss << "\n"; + oss << *it; + } + throw std::runtime_error( oss.str() ); + } + return unusedTokens; + } + std::vector populateFixedArgs( std::vector const& tokens, ConfigT& config ) const { + std::vector unusedTokens; + int position = 1; + for( std::size_t i = 0; i < tokens.size(); ++i ) { + Parser::Token const& token = tokens[i]; + typename std::map::const_iterator it = m_positionalArgs.find( position ); + if( it != m_positionalArgs.end() ) + it->second.boundField.set( config, token.data ); + else + unusedTokens.push_back( token ); + if( token.type == Parser::Token::Positional ) + position++; + } + return unusedTokens; + } + std::vector populateFloatingArgs( std::vector const& tokens, ConfigT& config ) const { + if( !m_floatingArg.get() ) + return tokens; + std::vector unusedTokens; + for( std::size_t i = 0; i < tokens.size(); ++i ) { + Parser::Token const& token = tokens[i]; + if( token.type == Parser::Token::Positional ) + m_floatingArg->boundField.set( config, token.data ); + else + unusedTokens.push_back( token ); + } + return unusedTokens; + } + + void validate() const + { + if( m_options.empty() && m_positionalArgs.empty() && !m_floatingArg.get() ) + throw std::logic_error( "No options or arguments specified" ); + + for( typename std::vector::const_iterator it = m_options.begin(), + itEnd = m_options.end(); + it != itEnd; ++it ) + it->validate(); + } + + private: + Detail::BoundArgFunction m_boundProcessName; + std::vector m_options; + std::map m_positionalArgs; + ArgAutoPtr m_floatingArg; + int m_highestSpecifiedArgPosition; + bool m_throwOnUnrecognisedTokens; + }; + +} // end namespace Clara + +STITCH_CLARA_CLOSE_NAMESPACE +#undef STITCH_CLARA_OPEN_NAMESPACE +#undef STITCH_CLARA_CLOSE_NAMESPACE + +#endif // TWOBLUECUBES_CLARA_H_INCLUDED +#undef STITCH_CLARA_OPEN_NAMESPACE + +// Restore Clara's value for console width, if present +#ifdef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#define CLARA_CONFIG_CONSOLE_WIDTH CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#undef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#endif + +#include +#include + +namespace Catch { + + inline void abortAfterFirst( ConfigData& config ) { config.abortAfter = 1; } + inline void abortAfterX( ConfigData& config, int x ) { + if( x < 1 ) + throw std::runtime_error( "Value after -x or --abortAfter must be greater than zero" ); + config.abortAfter = x; + } + inline void addTestOrTags( ConfigData& config, std::string const& _testSpec ) { config.testsOrTags.push_back( _testSpec ); } + inline void addSectionToRun( ConfigData& config, std::string const& sectionName ) { config.sectionsToRun.push_back( sectionName ); } + inline void addReporterName( ConfigData& config, std::string const& _reporterName ) { config.reporterNames.push_back( _reporterName ); } + + inline void addWarning( ConfigData& config, std::string const& _warning ) { + if( _warning == "NoAssertions" ) + config.warnings = static_cast( config.warnings | WarnAbout::NoAssertions ); + else + throw std::runtime_error( "Unrecognised warning: '" + _warning + '\'' ); + } + inline void setOrder( ConfigData& config, std::string const& order ) { + if( startsWith( "declared", order ) ) + config.runOrder = RunTests::InDeclarationOrder; + else if( startsWith( "lexical", order ) ) + config.runOrder = RunTests::InLexicographicalOrder; + else if( startsWith( "random", order ) ) + config.runOrder = RunTests::InRandomOrder; + else + throw std::runtime_error( "Unrecognised ordering: '" + order + '\'' ); + } + inline void setRngSeed( ConfigData& config, std::string const& seed ) { + if( seed == "time" ) { + config.rngSeed = static_cast( std::time(0) ); + } + else { + std::stringstream ss; + ss << seed; + ss >> config.rngSeed; + if( ss.fail() ) + throw std::runtime_error( "Argument to --rng-seed should be the word 'time' or a number" ); + } + } + inline void setVerbosity( ConfigData& config, int level ) { + // !TBD: accept strings? + config.verbosity = static_cast( level ); + } + inline void setShowDurations( ConfigData& config, bool _showDurations ) { + config.showDurations = _showDurations + ? ShowDurations::Always + : ShowDurations::Never; + } + inline void setUseColour( ConfigData& config, std::string const& value ) { + std::string mode = toLower( value ); + + if( mode == "yes" ) + config.useColour = UseColour::Yes; + else if( mode == "no" ) + config.useColour = UseColour::No; + else if( mode == "auto" ) + config.useColour = UseColour::Auto; + else + throw std::runtime_error( "colour mode must be one of: auto, yes or no" ); + } + inline void forceColour( ConfigData& config ) { + config.useColour = UseColour::Yes; + } + inline void loadTestNamesFromFile( ConfigData& config, std::string const& _filename ) { + std::ifstream f( _filename.c_str() ); + if( !f.is_open() ) + throw std::domain_error( "Unable to load input file: " + _filename ); + + std::string line; + while( std::getline( f, line ) ) { + line = trim(line); + if( !line.empty() && !startsWith( line, '#' ) ) { + if( !startsWith( line, '"' ) ) + line = '"' + line + '"'; + addTestOrTags( config, line + ',' ); + } + } + } + + inline Clara::CommandLine makeCommandLineParser() { + + using namespace Clara; + CommandLine cli; + + cli.bindProcessName( &ConfigData::processName ); + + cli["-?"]["-h"]["--help"] + .describe( "display usage information" ) + .bind( &ConfigData::showHelp ); + + cli["-l"]["--list-tests"] + .describe( "list all/matching test cases" ) + .bind( &ConfigData::listTests ); + + cli["-t"]["--list-tags"] + .describe( "list all/matching tags" ) + .bind( &ConfigData::listTags ); + + cli["-s"]["--success"] + .describe( "include successful tests in output" ) + .bind( &ConfigData::showSuccessfulTests ); + + cli["-b"]["--break"] + .describe( "break into debugger on failure" ) + .bind( &ConfigData::shouldDebugBreak ); + + cli["-e"]["--nothrow"] + .describe( "skip exception tests" ) + .bind( &ConfigData::noThrow ); + + cli["-i"]["--invisibles"] + .describe( "show invisibles (tabs, newlines)" ) + .bind( &ConfigData::showInvisibles ); + + cli["-o"]["--out"] + .describe( "output filename" ) + .bind( &ConfigData::outputFilename, "filename" ); + + cli["-r"]["--reporter"] +// .placeholder( "name[:filename]" ) + .describe( "reporter to use (defaults to console)" ) + .bind( &addReporterName, "name" ); + + cli["-n"]["--name"] + .describe( "suite name" ) + .bind( &ConfigData::name, "name" ); + + cli["-a"]["--abort"] + .describe( "abort at first failure" ) + .bind( &abortAfterFirst ); + + cli["-x"]["--abortx"] + .describe( "abort after x failures" ) + .bind( &abortAfterX, "no. failures" ); + + cli["-w"]["--warn"] + .describe( "enable warnings" ) + .bind( &addWarning, "warning name" ); + +// - needs updating if reinstated +// cli.into( &setVerbosity ) +// .describe( "level of verbosity (0=no output)" ) +// .shortOpt( "v") +// .longOpt( "verbosity" ) +// .placeholder( "level" ); + + cli[_] + .describe( "which test or tests to use" ) + .bind( &addTestOrTags, "test name, pattern or tags" ); + + cli["-d"]["--durations"] + .describe( "show test durations" ) + .bind( &setShowDurations, "yes|no" ); + + cli["-f"]["--input-file"] + .describe( "load test names to run from a file" ) + .bind( &loadTestNamesFromFile, "filename" ); + + cli["-#"]["--filenames-as-tags"] + .describe( "adds a tag for the filename" ) + .bind( &ConfigData::filenamesAsTags ); + + cli["-c"]["--section"] + .describe( "specify section to run" ) + .bind( &addSectionToRun, "section name" ); + + // Less common commands which don't have a short form + cli["--list-test-names-only"] + .describe( "list all/matching test cases names only" ) + .bind( &ConfigData::listTestNamesOnly ); + + cli["--list-extra-info"] + .describe( "list all/matching test cases with more info" ) + .bind( &ConfigData::listExtraInfo ); + + cli["--list-reporters"] + .describe( "list all reporters" ) + .bind( &ConfigData::listReporters ); + + cli["--order"] + .describe( "test case order (defaults to decl)" ) + .bind( &setOrder, "decl|lex|rand" ); + + cli["--rng-seed"] + .describe( "set a specific seed for random numbers" ) + .bind( &setRngSeed, "'time'|number" ); + + cli["--force-colour"] + .describe( "force colourised output (deprecated)" ) + .bind( &forceColour ); + + cli["--use-colour"] + .describe( "should output be colourised" ) + .bind( &setUseColour, "yes|no" ); + + return cli; + } + +} // end namespace Catch + +// #included from: internal/catch_list.hpp +#define TWOBLUECUBES_CATCH_LIST_HPP_INCLUDED + +// #included from: catch_text.h +#define TWOBLUECUBES_CATCH_TEXT_H_INCLUDED + +#define TBC_TEXT_FORMAT_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH + +#define CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE Catch +// #included from: ../external/tbc_text_format.h +// Only use header guard if we are not using an outer namespace +#ifndef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE +# ifdef TWOBLUECUBES_TEXT_FORMAT_H_INCLUDED +# ifndef TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +# define TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +# endif +# else +# define TWOBLUECUBES_TEXT_FORMAT_H_INCLUDED +# endif +#endif +#ifndef TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +#include +#include +#include + +// Use optional outer namespace +#ifdef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE +namespace CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE { +#endif + +namespace Tbc { + +#ifdef TBC_TEXT_FORMAT_CONSOLE_WIDTH + const unsigned int consoleWidth = TBC_TEXT_FORMAT_CONSOLE_WIDTH; +#else + const unsigned int consoleWidth = 80; +#endif + + struct TextAttributes { + TextAttributes() + : initialIndent( std::string::npos ), + indent( 0 ), + width( consoleWidth-1 ) + {} + + TextAttributes& setInitialIndent( std::size_t _value ) { initialIndent = _value; return *this; } + TextAttributes& setIndent( std::size_t _value ) { indent = _value; return *this; } + TextAttributes& setWidth( std::size_t _value ) { width = _value; return *this; } + + std::size_t initialIndent; // indent of first line, or npos + std::size_t indent; // indent of subsequent lines, or all if initialIndent is npos + std::size_t width; // maximum width of text, including indent. Longer text will wrap + }; + + class Text { + public: + Text( std::string const& _str, TextAttributes const& _attr = TextAttributes() ) + : attr( _attr ) + { + const std::string wrappableBeforeChars = "[({<\t"; + const std::string wrappableAfterChars = "])}>-,./|\\"; + const std::string wrappableInsteadOfChars = " \n\r"; + std::string indent = _attr.initialIndent != std::string::npos + ? std::string( _attr.initialIndent, ' ' ) + : std::string( _attr.indent, ' ' ); + + typedef std::string::const_iterator iterator; + iterator it = _str.begin(); + const iterator strEnd = _str.end(); + + while( it != strEnd ) { + + if( lines.size() >= 1000 ) { + lines.push_back( "... message truncated due to excessive size" ); + return; + } + + std::string suffix; + std::size_t width = (std::min)( static_cast( strEnd-it ), _attr.width-static_cast( indent.size() ) ); + iterator itEnd = it+width; + iterator itNext = _str.end(); + + iterator itNewLine = std::find( it, itEnd, '\n' ); + if( itNewLine != itEnd ) + itEnd = itNewLine; + + if( itEnd != strEnd ) { + bool foundWrapPoint = false; + iterator findIt = itEnd; + do { + if( wrappableAfterChars.find( *findIt ) != std::string::npos && findIt != itEnd ) { + itEnd = findIt+1; + itNext = findIt+1; + foundWrapPoint = true; + } + else if( findIt > it && wrappableBeforeChars.find( *findIt ) != std::string::npos ) { + itEnd = findIt; + itNext = findIt; + foundWrapPoint = true; + } + else if( wrappableInsteadOfChars.find( *findIt ) != std::string::npos ) { + itNext = findIt+1; + itEnd = findIt; + foundWrapPoint = true; + } + if( findIt == it ) + break; + else + --findIt; + } + while( !foundWrapPoint ); + + if( !foundWrapPoint ) { + // No good wrap char, so we'll break mid word and add a hyphen + --itEnd; + itNext = itEnd; + suffix = "-"; + } + else { + while( itEnd > it && wrappableInsteadOfChars.find( *(itEnd-1) ) != std::string::npos ) + --itEnd; + } + } + lines.push_back( indent + std::string( it, itEnd ) + suffix ); + + if( indent.size() != _attr.indent ) + indent = std::string( _attr.indent, ' ' ); + it = itNext; + } + } + + typedef std::vector::const_iterator const_iterator; + + const_iterator begin() const { return lines.begin(); } + const_iterator end() const { return lines.end(); } + std::string const& last() const { return lines.back(); } + std::size_t size() const { return lines.size(); } + std::string const& operator[]( std::size_t _index ) const { return lines[_index]; } + std::string toString() const { + std::ostringstream oss; + oss << *this; + return oss.str(); + } + + inline friend std::ostream& operator << ( std::ostream& _stream, Text const& _text ) { + for( Text::const_iterator it = _text.begin(), itEnd = _text.end(); + it != itEnd; ++it ) { + if( it != _text.begin() ) + _stream << "\n"; + _stream << *it; + } + return _stream; + } + + private: + std::string str; + TextAttributes attr; + std::vector lines; + }; + +} // end namespace Tbc + +#ifdef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE +} // end outer namespace +#endif + +#endif // TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +#undef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE + +namespace Catch { + using Tbc::Text; + using Tbc::TextAttributes; +} + +// #included from: catch_console_colour.hpp +#define TWOBLUECUBES_CATCH_CONSOLE_COLOUR_HPP_INCLUDED + +namespace Catch { + + struct Colour { + enum Code { + None = 0, + + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White, + + // By intention + FileName = LightGrey, + Warning = Yellow, + ResultError = BrightRed, + ResultSuccess = BrightGreen, + ResultExpectedFailure = Warning, + + Error = BrightRed, + Success = Green, + + OriginalExpression = Cyan, + ReconstructedExpression = Yellow, + + SecondaryText = LightGrey, + Headers = White + }; + + // Use constructed object for RAII guard + Colour( Code _colourCode ); + Colour( Colour const& other ); + ~Colour(); + + // Use static method for one-shot changes + static void use( Code _colourCode ); + + private: + bool m_moved; + }; + + inline std::ostream& operator << ( std::ostream& os, Colour const& ) { return os; } + +} // end namespace Catch + +// #included from: catch_interfaces_reporter.h +#define TWOBLUECUBES_CATCH_INTERFACES_REPORTER_H_INCLUDED + +#include +#include +#include + +namespace Catch +{ + struct ReporterConfig { + explicit ReporterConfig( Ptr const& _fullConfig ) + : m_stream( &_fullConfig->stream() ), m_fullConfig( _fullConfig ) {} + + ReporterConfig( Ptr const& _fullConfig, std::ostream& _stream ) + : m_stream( &_stream ), m_fullConfig( _fullConfig ) {} + + std::ostream& stream() const { return *m_stream; } + Ptr fullConfig() const { return m_fullConfig; } + + private: + std::ostream* m_stream; + Ptr m_fullConfig; + }; + + struct ReporterPreferences { + ReporterPreferences() + : shouldRedirectStdOut( false ) + {} + + bool shouldRedirectStdOut; + }; + + template + struct LazyStat : Option { + LazyStat() : used( false ) {} + LazyStat& operator=( T const& _value ) { + Option::operator=( _value ); + used = false; + return *this; + } + void reset() { + Option::reset(); + used = false; + } + bool used; + }; + + struct TestRunInfo { + TestRunInfo( std::string const& _name ) : name( _name ) {} + std::string name; + }; + struct GroupInfo { + GroupInfo( std::string const& _name, + std::size_t _groupIndex, + std::size_t _groupsCount ) + : name( _name ), + groupIndex( _groupIndex ), + groupsCounts( _groupsCount ) + {} + + std::string name; + std::size_t groupIndex; + std::size_t groupsCounts; + }; + + struct AssertionStats { + AssertionStats( AssertionResult const& _assertionResult, + std::vector const& _infoMessages, + Totals const& _totals ) + : assertionResult( _assertionResult ), + infoMessages( _infoMessages ), + totals( _totals ) + { + if( assertionResult.hasMessage() ) { + // Copy message into messages list. + // !TBD This should have been done earlier, somewhere + MessageBuilder builder( assertionResult.getTestMacroName(), assertionResult.getSourceInfo(), assertionResult.getResultType() ); + builder << assertionResult.getMessage(); + builder.m_info.message = builder.m_stream.str(); + + infoMessages.push_back( builder.m_info ); + } + } + virtual ~AssertionStats(); + +# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + AssertionStats( AssertionStats const& ) = default; + AssertionStats( AssertionStats && ) = default; + AssertionStats& operator = ( AssertionStats const& ) = default; + AssertionStats& operator = ( AssertionStats && ) = default; +# endif + + AssertionResult assertionResult; + std::vector infoMessages; + Totals totals; + }; + + struct SectionStats { + SectionStats( SectionInfo const& _sectionInfo, + Counts const& _assertions, + double _durationInSeconds, + bool _missingAssertions ) + : sectionInfo( _sectionInfo ), + assertions( _assertions ), + durationInSeconds( _durationInSeconds ), + missingAssertions( _missingAssertions ) + {} + virtual ~SectionStats(); +# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + SectionStats( SectionStats const& ) = default; + SectionStats( SectionStats && ) = default; + SectionStats& operator = ( SectionStats const& ) = default; + SectionStats& operator = ( SectionStats && ) = default; +# endif + + SectionInfo sectionInfo; + Counts assertions; + double durationInSeconds; + bool missingAssertions; + }; + + struct TestCaseStats { + TestCaseStats( TestCaseInfo const& _testInfo, + Totals const& _totals, + std::string const& _stdOut, + std::string const& _stdErr, + bool _aborting ) + : testInfo( _testInfo ), + totals( _totals ), + stdOut( _stdOut ), + stdErr( _stdErr ), + aborting( _aborting ) + {} + virtual ~TestCaseStats(); + +# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + TestCaseStats( TestCaseStats const& ) = default; + TestCaseStats( TestCaseStats && ) = default; + TestCaseStats& operator = ( TestCaseStats const& ) = default; + TestCaseStats& operator = ( TestCaseStats && ) = default; +# endif + + TestCaseInfo testInfo; + Totals totals; + std::string stdOut; + std::string stdErr; + bool aborting; + }; + + struct TestGroupStats { + TestGroupStats( GroupInfo const& _groupInfo, + Totals const& _totals, + bool _aborting ) + : groupInfo( _groupInfo ), + totals( _totals ), + aborting( _aborting ) + {} + TestGroupStats( GroupInfo const& _groupInfo ) + : groupInfo( _groupInfo ), + aborting( false ) + {} + virtual ~TestGroupStats(); + +# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS + TestGroupStats( TestGroupStats const& ) = default; + TestGroupStats( TestGroupStats && ) = default; + TestGroupStats& operator = ( TestGroupStats const& ) = default; + TestGroupStats& operator = ( TestGroupStats && ) = default; +# endif + + GroupInfo groupInfo; + Totals totals; + bool aborting; + }; + + struct TestRunStats { + TestRunStats( TestRunInfo const& _runInfo, + Totals const& _totals, + bool _aborting ) + : runInfo( _runInfo ), + totals( _totals ), + aborting( _aborting ) + {} + virtual ~TestRunStats(); + +# ifndef CATCH_CONFIG_CPP11_GENERATED_METHODS + TestRunStats( TestRunStats const& _other ) + : runInfo( _other.runInfo ), + totals( _other.totals ), + aborting( _other.aborting ) + {} +# else + TestRunStats( TestRunStats const& ) = default; + TestRunStats( TestRunStats && ) = default; + TestRunStats& operator = ( TestRunStats const& ) = default; + TestRunStats& operator = ( TestRunStats && ) = default; +# endif + + TestRunInfo runInfo; + Totals totals; + bool aborting; + }; + + class MultipleReporters; + + struct IStreamingReporter : IShared { + virtual ~IStreamingReporter(); + + // Implementing class must also provide the following static method: + // static std::string getDescription(); + + virtual ReporterPreferences getPreferences() const = 0; + + virtual void noMatchingTestCases( std::string const& spec ) = 0; + + virtual void testRunStarting( TestRunInfo const& testRunInfo ) = 0; + virtual void testGroupStarting( GroupInfo const& groupInfo ) = 0; + + virtual void testCaseStarting( TestCaseInfo const& testInfo ) = 0; + virtual void sectionStarting( SectionInfo const& sectionInfo ) = 0; + + virtual void assertionStarting( AssertionInfo const& assertionInfo ) = 0; + + // The return value indicates if the messages buffer should be cleared: + virtual bool assertionEnded( AssertionStats const& assertionStats ) = 0; + + virtual void sectionEnded( SectionStats const& sectionStats ) = 0; + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) = 0; + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) = 0; + virtual void testRunEnded( TestRunStats const& testRunStats ) = 0; + + virtual void skipTest( TestCaseInfo const& testInfo ) = 0; + + virtual MultipleReporters* tryAsMulti() { return CATCH_NULL; } + }; + + struct IReporterFactory : IShared { + virtual ~IReporterFactory(); + virtual IStreamingReporter* create( ReporterConfig const& config ) const = 0; + virtual std::string getDescription() const = 0; + }; + + struct IReporterRegistry { + typedef std::map > FactoryMap; + typedef std::vector > Listeners; + + virtual ~IReporterRegistry(); + virtual IStreamingReporter* create( std::string const& name, Ptr const& config ) const = 0; + virtual FactoryMap const& getFactories() const = 0; + virtual Listeners const& getListeners() const = 0; + }; + + Ptr addReporter( Ptr const& existingReporter, Ptr const& additionalReporter ); + +} + +#include +#include + +namespace Catch { + + inline std::size_t listTests( Config const& config ) { + + TestSpec testSpec = config.testSpec(); + if( config.testSpec().hasFilters() ) + Catch::cout() << "Matching test cases:\n"; + else { + Catch::cout() << "All available test cases:\n"; + testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "*" ).testSpec(); + } + + std::size_t matchedTests = 0; + TextAttributes nameAttr, descAttr, tagsAttr; + nameAttr.setInitialIndent( 2 ).setIndent( 4 ); + descAttr.setIndent( 4 ); + tagsAttr.setIndent( 6 ); + + std::vector matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( std::vector::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end(); + it != itEnd; + ++it ) { + matchedTests++; + TestCaseInfo const& testCaseInfo = it->getTestCaseInfo(); + Colour::Code colour = testCaseInfo.isHidden() + ? Colour::SecondaryText + : Colour::None; + Colour colourGuard( colour ); + + Catch::cout() << Text( testCaseInfo.name, nameAttr ) << std::endl; + if( config.listExtraInfo() ) { + Catch::cout() << " " << testCaseInfo.lineInfo << std::endl; + std::string description = testCaseInfo.description; + if( description.empty() ) + description = "(NO DESCRIPTION)"; + Catch::cout() << Text( description, descAttr ) << std::endl; + } + if( !testCaseInfo.tags.empty() ) + Catch::cout() << Text( testCaseInfo.tagsAsString, tagsAttr ) << std::endl; + } + + if( !config.testSpec().hasFilters() ) + Catch::cout() << pluralise( matchedTests, "test case" ) << '\n' << std::endl; + else + Catch::cout() << pluralise( matchedTests, "matching test case" ) << '\n' << std::endl; + return matchedTests; + } + + inline std::size_t listTestsNamesOnly( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( !config.testSpec().hasFilters() ) + testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "*" ).testSpec(); + std::size_t matchedTests = 0; + std::vector matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( std::vector::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end(); + it != itEnd; + ++it ) { + matchedTests++; + TestCaseInfo const& testCaseInfo = it->getTestCaseInfo(); + if( startsWith( testCaseInfo.name, '#' ) ) + Catch::cout() << '"' << testCaseInfo.name << '"'; + else + Catch::cout() << testCaseInfo.name; + if ( config.listExtraInfo() ) + Catch::cout() << "\t@" << testCaseInfo.lineInfo; + Catch::cout() << std::endl; + } + return matchedTests; + } + + struct TagInfo { + TagInfo() : count ( 0 ) {} + void add( std::string const& spelling ) { + ++count; + spellings.insert( spelling ); + } + std::string all() const { + std::string out; + for( std::set::const_iterator it = spellings.begin(), itEnd = spellings.end(); + it != itEnd; + ++it ) + out += "[" + *it + "]"; + return out; + } + std::set spellings; + std::size_t count; + }; + + inline std::size_t listTags( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( config.testSpec().hasFilters() ) + Catch::cout() << "Tags for matching test cases:\n"; + else { + Catch::cout() << "All available tags:\n"; + testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "*" ).testSpec(); + } + + std::map tagCounts; + + std::vector matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( std::vector::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end(); + it != itEnd; + ++it ) { + for( std::set::const_iterator tagIt = it->getTestCaseInfo().tags.begin(), + tagItEnd = it->getTestCaseInfo().tags.end(); + tagIt != tagItEnd; + ++tagIt ) { + std::string tagName = *tagIt; + std::string lcaseTagName = toLower( tagName ); + std::map::iterator countIt = tagCounts.find( lcaseTagName ); + if( countIt == tagCounts.end() ) + countIt = tagCounts.insert( std::make_pair( lcaseTagName, TagInfo() ) ).first; + countIt->second.add( tagName ); + } + } + + for( std::map::const_iterator countIt = tagCounts.begin(), + countItEnd = tagCounts.end(); + countIt != countItEnd; + ++countIt ) { + std::ostringstream oss; + oss << " " << std::setw(2) << countIt->second.count << " "; + Text wrapper( countIt->second.all(), TextAttributes() + .setInitialIndent( 0 ) + .setIndent( oss.str().size() ) + .setWidth( CATCH_CONFIG_CONSOLE_WIDTH-10 ) ); + Catch::cout() << oss.str() << wrapper << '\n'; + } + Catch::cout() << pluralise( tagCounts.size(), "tag" ) << '\n' << std::endl; + return tagCounts.size(); + } + + inline std::size_t listReporters( Config const& /*config*/ ) { + Catch::cout() << "Available reporters:\n"; + IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories(); + IReporterRegistry::FactoryMap::const_iterator itBegin = factories.begin(), itEnd = factories.end(), it; + std::size_t maxNameLen = 0; + for(it = itBegin; it != itEnd; ++it ) + maxNameLen = (std::max)( maxNameLen, it->first.size() ); + + for(it = itBegin; it != itEnd; ++it ) { + Text wrapper( it->second->getDescription(), TextAttributes() + .setInitialIndent( 0 ) + .setIndent( 7+maxNameLen ) + .setWidth( CATCH_CONFIG_CONSOLE_WIDTH - maxNameLen-8 ) ); + Catch::cout() << " " + << it->first + << ':' + << std::string( maxNameLen - it->first.size() + 2, ' ' ) + << wrapper << '\n'; + } + Catch::cout() << std::endl; + return factories.size(); + } + + inline Option list( Config const& config ) { + Option listedCount; + if( config.listTests() || ( config.listExtraInfo() && !config.listTestNamesOnly() ) ) + listedCount = listedCount.valueOr(0) + listTests( config ); + if( config.listTestNamesOnly() ) + listedCount = listedCount.valueOr(0) + listTestsNamesOnly( config ); + if( config.listTags() ) + listedCount = listedCount.valueOr(0) + listTags( config ); + if( config.listReporters() ) + listedCount = listedCount.valueOr(0) + listReporters( config ); + return listedCount; + } + +} // end namespace Catch + +// #included from: internal/catch_run_context.hpp +#define TWOBLUECUBES_CATCH_RUNNER_IMPL_HPP_INCLUDED + +// #included from: catch_test_case_tracker.hpp +#define TWOBLUECUBES_CATCH_TEST_CASE_TRACKER_HPP_INCLUDED + +#include +#include +#include +#include +#include + +CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS + +namespace Catch { +namespace TestCaseTracking { + + struct NameAndLocation { + std::string name; + SourceLineInfo location; + + NameAndLocation( std::string const& _name, SourceLineInfo const& _location ) + : name( _name ), + location( _location ) + {} + }; + + struct ITracker : SharedImpl<> { + virtual ~ITracker(); + + // static queries + virtual NameAndLocation const& nameAndLocation() const = 0; + + // dynamic queries + virtual bool isComplete() const = 0; // Successfully completed or failed + virtual bool isSuccessfullyCompleted() const = 0; + virtual bool isOpen() const = 0; // Started but not complete + virtual bool hasChildren() const = 0; + + virtual ITracker& parent() = 0; + + // actions + virtual void close() = 0; // Successfully complete + virtual void fail() = 0; + virtual void markAsNeedingAnotherRun() = 0; + + virtual void addChild( Ptr const& child ) = 0; + virtual ITracker* findChild( NameAndLocation const& nameAndLocation ) = 0; + virtual void openChild() = 0; + + // Debug/ checking + virtual bool isSectionTracker() const = 0; + virtual bool isIndexTracker() const = 0; + }; + + class TrackerContext { + + enum RunState { + NotStarted, + Executing, + CompletedCycle + }; + + Ptr m_rootTracker; + ITracker* m_currentTracker; + RunState m_runState; + + public: + + static TrackerContext& instance() { + static TrackerContext s_instance; + return s_instance; + } + + TrackerContext() + : m_currentTracker( CATCH_NULL ), + m_runState( NotStarted ) + {} + + ITracker& startRun(); + + void endRun() { + m_rootTracker.reset(); + m_currentTracker = CATCH_NULL; + m_runState = NotStarted; + } + + void startCycle() { + m_currentTracker = m_rootTracker.get(); + m_runState = Executing; + } + void completeCycle() { + m_runState = CompletedCycle; + } + + bool completedCycle() const { + return m_runState == CompletedCycle; + } + ITracker& currentTracker() { + return *m_currentTracker; + } + void setCurrentTracker( ITracker* tracker ) { + m_currentTracker = tracker; + } + }; + + class TrackerBase : public ITracker { + protected: + enum CycleState { + NotStarted, + Executing, + ExecutingChildren, + NeedsAnotherRun, + CompletedSuccessfully, + Failed + }; + class TrackerHasName { + NameAndLocation m_nameAndLocation; + public: + TrackerHasName( NameAndLocation const& nameAndLocation ) : m_nameAndLocation( nameAndLocation ) {} + bool operator ()( Ptr const& tracker ) { + return + tracker->nameAndLocation().name == m_nameAndLocation.name && + tracker->nameAndLocation().location == m_nameAndLocation.location; + } + }; + typedef std::vector > Children; + NameAndLocation m_nameAndLocation; + TrackerContext& m_ctx; + ITracker* m_parent; + Children m_children; + CycleState m_runState; + public: + TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : m_nameAndLocation( nameAndLocation ), + m_ctx( ctx ), + m_parent( parent ), + m_runState( NotStarted ) + {} + virtual ~TrackerBase(); + + virtual NameAndLocation const& nameAndLocation() const CATCH_OVERRIDE { + return m_nameAndLocation; + } + virtual bool isComplete() const CATCH_OVERRIDE { + return m_runState == CompletedSuccessfully || m_runState == Failed; + } + virtual bool isSuccessfullyCompleted() const CATCH_OVERRIDE { + return m_runState == CompletedSuccessfully; + } + virtual bool isOpen() const CATCH_OVERRIDE { + return m_runState != NotStarted && !isComplete(); + } + virtual bool hasChildren() const CATCH_OVERRIDE { + return !m_children.empty(); + } + + virtual void addChild( Ptr const& child ) CATCH_OVERRIDE { + m_children.push_back( child ); + } + + virtual ITracker* findChild( NameAndLocation const& nameAndLocation ) CATCH_OVERRIDE { + Children::const_iterator it = std::find_if( m_children.begin(), m_children.end(), TrackerHasName( nameAndLocation ) ); + return( it != m_children.end() ) + ? it->get() + : CATCH_NULL; + } + virtual ITracker& parent() CATCH_OVERRIDE { + assert( m_parent ); // Should always be non-null except for root + return *m_parent; + } + + virtual void openChild() CATCH_OVERRIDE { + if( m_runState != ExecutingChildren ) { + m_runState = ExecutingChildren; + if( m_parent ) + m_parent->openChild(); + } + } + + virtual bool isSectionTracker() const CATCH_OVERRIDE { return false; } + virtual bool isIndexTracker() const CATCH_OVERRIDE { return false; } + + void open() { + m_runState = Executing; + moveToThis(); + if( m_parent ) + m_parent->openChild(); + } + + virtual void close() CATCH_OVERRIDE { + + // Close any still open children (e.g. generators) + while( &m_ctx.currentTracker() != this ) + m_ctx.currentTracker().close(); + + switch( m_runState ) { + case NotStarted: + case CompletedSuccessfully: + case Failed: + throw std::logic_error( "Illogical state" ); + + case NeedsAnotherRun: + break;; + + case Executing: + m_runState = CompletedSuccessfully; + break; + case ExecutingChildren: + if( m_children.empty() || m_children.back()->isComplete() ) + m_runState = CompletedSuccessfully; + break; + + default: + throw std::logic_error( "Unexpected state" ); + } + moveToParent(); + m_ctx.completeCycle(); + } + virtual void fail() CATCH_OVERRIDE { + m_runState = Failed; + if( m_parent ) + m_parent->markAsNeedingAnotherRun(); + moveToParent(); + m_ctx.completeCycle(); + } + virtual void markAsNeedingAnotherRun() CATCH_OVERRIDE { + m_runState = NeedsAnotherRun; + } + private: + void moveToParent() { + assert( m_parent ); + m_ctx.setCurrentTracker( m_parent ); + } + void moveToThis() { + m_ctx.setCurrentTracker( this ); + } + }; + + class SectionTracker : public TrackerBase { + std::vector m_filters; + public: + SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : TrackerBase( nameAndLocation, ctx, parent ) + { + if( parent ) { + while( !parent->isSectionTracker() ) + parent = &parent->parent(); + + SectionTracker& parentSection = static_cast( *parent ); + addNextFilters( parentSection.m_filters ); + } + } + virtual ~SectionTracker(); + + virtual bool isSectionTracker() const CATCH_OVERRIDE { return true; } + + static SectionTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ) { + SectionTracker* section = CATCH_NULL; + + ITracker& currentTracker = ctx.currentTracker(); + if( ITracker* childTracker = currentTracker.findChild( nameAndLocation ) ) { + assert( childTracker ); + assert( childTracker->isSectionTracker() ); + section = static_cast( childTracker ); + } + else { + section = new SectionTracker( nameAndLocation, ctx, ¤tTracker ); + currentTracker.addChild( section ); + } + if( !ctx.completedCycle() ) + section->tryOpen(); + return *section; + } + + void tryOpen() { + if( !isComplete() && (m_filters.empty() || m_filters[0].empty() || m_filters[0] == m_nameAndLocation.name ) ) + open(); + } + + void addInitialFilters( std::vector const& filters ) { + if( !filters.empty() ) { + m_filters.push_back(""); // Root - should never be consulted + m_filters.push_back(""); // Test Case - not a section filter + m_filters.insert( m_filters.end(), filters.begin(), filters.end() ); + } + } + void addNextFilters( std::vector const& filters ) { + if( filters.size() > 1 ) + m_filters.insert( m_filters.end(), ++filters.begin(), filters.end() ); + } + }; + + class IndexTracker : public TrackerBase { + int m_size; + int m_index; + public: + IndexTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent, int size ) + : TrackerBase( nameAndLocation, ctx, parent ), + m_size( size ), + m_index( -1 ) + {} + virtual ~IndexTracker(); + + virtual bool isIndexTracker() const CATCH_OVERRIDE { return true; } + + static IndexTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation, int size ) { + IndexTracker* tracker = CATCH_NULL; + + ITracker& currentTracker = ctx.currentTracker(); + if( ITracker* childTracker = currentTracker.findChild( nameAndLocation ) ) { + assert( childTracker ); + assert( childTracker->isIndexTracker() ); + tracker = static_cast( childTracker ); + } + else { + tracker = new IndexTracker( nameAndLocation, ctx, ¤tTracker, size ); + currentTracker.addChild( tracker ); + } + + if( !ctx.completedCycle() && !tracker->isComplete() ) { + if( tracker->m_runState != ExecutingChildren && tracker->m_runState != NeedsAnotherRun ) + tracker->moveNext(); + tracker->open(); + } + + return *tracker; + } + + int index() const { return m_index; } + + void moveNext() { + m_index++; + m_children.clear(); + } + + virtual void close() CATCH_OVERRIDE { + TrackerBase::close(); + if( m_runState == CompletedSuccessfully && m_index < m_size-1 ) + m_runState = Executing; + } + }; + + inline ITracker& TrackerContext::startRun() { + m_rootTracker = new SectionTracker( NameAndLocation( "{root}", CATCH_INTERNAL_LINEINFO ), *this, CATCH_NULL ); + m_currentTracker = CATCH_NULL; + m_runState = Executing; + return *m_rootTracker; + } + +} // namespace TestCaseTracking + +using TestCaseTracking::ITracker; +using TestCaseTracking::TrackerContext; +using TestCaseTracking::SectionTracker; +using TestCaseTracking::IndexTracker; + +} // namespace Catch + +CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS + +// #included from: catch_fatal_condition.hpp +#define TWOBLUECUBES_CATCH_FATAL_CONDITION_H_INCLUDED + +namespace Catch { + + // Report the error condition + inline void reportFatal( std::string const& message ) { + IContext& context = Catch::getCurrentContext(); + IResultCapture* resultCapture = context.getResultCapture(); + resultCapture->handleFatalErrorCondition( message ); + } + +} // namespace Catch + +#if defined ( CATCH_PLATFORM_WINDOWS ) ///////////////////////////////////////// +// #included from: catch_windows_h_proxy.h + +#define TWOBLUECUBES_CATCH_WINDOWS_H_PROXY_H_INCLUDED + +#ifdef CATCH_DEFINES_NOMINMAX +# define NOMINMAX +#endif +#ifdef CATCH_DEFINES_WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +#endif + +#ifdef __AFXDLL +#include +#else +#include +#endif + +#ifdef CATCH_DEFINES_NOMINMAX +# undef NOMINMAX +#endif +#ifdef CATCH_DEFINES_WIN32_LEAN_AND_MEAN +# undef WIN32_LEAN_AND_MEAN +#endif + + +# if !defined ( CATCH_CONFIG_WINDOWS_SEH ) + +namespace Catch { + struct FatalConditionHandler { + void reset() {} + }; +} + +# else // CATCH_CONFIG_WINDOWS_SEH is defined + +namespace Catch { + + struct SignalDefs { DWORD id; const char* name; }; + extern SignalDefs signalDefs[]; + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + SignalDefs signalDefs[] = { + { EXCEPTION_ILLEGAL_INSTRUCTION, "SIGILL - Illegal instruction signal" }, + { EXCEPTION_STACK_OVERFLOW, "SIGSEGV - Stack overflow" }, + { EXCEPTION_ACCESS_VIOLATION, "SIGSEGV - Segmentation violation signal" }, + { EXCEPTION_INT_DIVIDE_BY_ZERO, "Divide by zero error" }, + }; + + struct FatalConditionHandler { + + static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) { + for (int i = 0; i < sizeof(signalDefs) / sizeof(SignalDefs); ++i) { + if (ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { + reportFatal(signalDefs[i].name); + } + } + // If its not an exception we care about, pass it along. + // This stops us from eating debugger breaks etc. + return EXCEPTION_CONTINUE_SEARCH; + } + + FatalConditionHandler() { + isSet = true; + // 32k seems enough for Catch to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + exceptionHandlerHandle = CATCH_NULL; + // Register as first handler in current chain + exceptionHandlerHandle = AddVectoredExceptionHandler(1, handleVectoredException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + } + + static void reset() { + if (isSet) { + // Unregister handler and restore the old guarantee + RemoveVectoredExceptionHandler(exceptionHandlerHandle); + SetThreadStackGuarantee(&guaranteeSize); + exceptionHandlerHandle = CATCH_NULL; + isSet = false; + } + } + + ~FatalConditionHandler() { + reset(); + } + private: + static bool isSet; + static ULONG guaranteeSize; + static PVOID exceptionHandlerHandle; + }; + + bool FatalConditionHandler::isSet = false; + ULONG FatalConditionHandler::guaranteeSize = 0; + PVOID FatalConditionHandler::exceptionHandlerHandle = CATCH_NULL; + +} // namespace Catch + +# endif // CATCH_CONFIG_WINDOWS_SEH + +#else // Not Windows - assumed to be POSIX compatible ////////////////////////// + +# if !defined(CATCH_CONFIG_POSIX_SIGNALS) + +namespace Catch { + struct FatalConditionHandler { + void reset() {} + }; +} + +# else // CATCH_CONFIG_POSIX_SIGNALS is defined + +#include + +namespace Catch { + + struct SignalDefs { + int id; + const char* name; + }; + extern SignalDefs signalDefs[]; + SignalDefs signalDefs[] = { + { SIGINT, "SIGINT - Terminal interrupt signal" }, + { SIGILL, "SIGILL - Illegal instruction signal" }, + { SIGFPE, "SIGFPE - Floating point error signal" }, + { SIGSEGV, "SIGSEGV - Segmentation violation signal" }, + { SIGTERM, "SIGTERM - Termination request signal" }, + { SIGABRT, "SIGABRT - Abort (abnormal termination) signal" } + }; + + struct FatalConditionHandler { + + static bool isSet; + static struct sigaction oldSigActions [sizeof(signalDefs)/sizeof(SignalDefs)]; + static stack_t oldSigStack; + static char altStackMem[SIGSTKSZ]; + + static void handleSignal( int sig ) { + std::string name = ""; + for (std::size_t i = 0; i < sizeof(signalDefs) / sizeof(SignalDefs); ++i) { + SignalDefs &def = signalDefs[i]; + if (sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise( sig ); + } + + FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = SIGSTKSZ; + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = { 0 }; + + sa.sa_handler = handleSignal; + sa.sa_flags = SA_ONSTACK; + for (std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + ~FatalConditionHandler() { + reset(); + } + static void reset() { + if( isSet ) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for( std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i ) { + sigaction(signalDefs[i].id, &oldSigActions[i], CATCH_NULL); + } + // Return the old stack + sigaltstack(&oldSigStack, CATCH_NULL); + isSet = false; + } + } + }; + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[sizeof(signalDefs)/sizeof(SignalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + char FatalConditionHandler::altStackMem[SIGSTKSZ] = {}; + +} // namespace Catch + +# endif // CATCH_CONFIG_POSIX_SIGNALS + +#endif // not Windows + +#include +#include + +namespace Catch { + + class StreamRedirect { + + public: + StreamRedirect( std::ostream& stream, std::string& targetString ) + : m_stream( stream ), + m_prevBuf( stream.rdbuf() ), + m_targetString( targetString ) + { + stream.rdbuf( m_oss.rdbuf() ); + } + + ~StreamRedirect() { + m_targetString += m_oss.str(); + m_stream.rdbuf( m_prevBuf ); + } + + private: + std::ostream& m_stream; + std::streambuf* m_prevBuf; + std::ostringstream m_oss; + std::string& m_targetString; + }; + + /////////////////////////////////////////////////////////////////////////// + + class RunContext : public IResultCapture, public IRunner { + + RunContext( RunContext const& ); + void operator =( RunContext const& ); + + public: + + explicit RunContext( Ptr const& _config, Ptr const& reporter ) + : m_runInfo( _config->name() ), + m_context( getCurrentMutableContext() ), + m_activeTestCase( CATCH_NULL ), + m_config( _config ), + m_reporter( reporter ), + m_shouldReportUnexpected ( true ) + { + m_context.setRunner( this ); + m_context.setConfig( m_config ); + m_context.setResultCapture( this ); + m_reporter->testRunStarting( m_runInfo ); + } + + virtual ~RunContext() { + m_reporter->testRunEnded( TestRunStats( m_runInfo, m_totals, aborting() ) ); + } + + void testGroupStarting( std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount ) { + m_reporter->testGroupStarting( GroupInfo( testSpec, groupIndex, groupsCount ) ); + } + void testGroupEnded( std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount ) { + m_reporter->testGroupEnded( TestGroupStats( GroupInfo( testSpec, groupIndex, groupsCount ), totals, aborting() ) ); + } + + Totals runTest( TestCase const& testCase ) { + Totals prevTotals = m_totals; + + std::string redirectedCout; + std::string redirectedCerr; + + TestCaseInfo testInfo = testCase.getTestCaseInfo(); + + m_reporter->testCaseStarting( testInfo ); + + m_activeTestCase = &testCase; + + do { + ITracker& rootTracker = m_trackerContext.startRun(); + assert( rootTracker.isSectionTracker() ); + static_cast( rootTracker ).addInitialFilters( m_config->getSectionsToRun() ); + do { + m_trackerContext.startCycle(); + m_testCaseTracker = &SectionTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( testInfo.name, testInfo.lineInfo ) ); + runCurrentTest( redirectedCout, redirectedCerr ); + } + while( !m_testCaseTracker->isSuccessfullyCompleted() && !aborting() ); + } + // !TBD: deprecated - this will be replaced by indexed trackers + while( getCurrentContext().advanceGeneratorsForCurrentTest() && !aborting() ); + + Totals deltaTotals = m_totals.delta( prevTotals ); + if( testInfo.expectedToFail() && deltaTotals.testCases.passed > 0 ) { + deltaTotals.assertions.failed++; + deltaTotals.testCases.passed--; + deltaTotals.testCases.failed++; + } + m_totals.testCases += deltaTotals.testCases; + m_reporter->testCaseEnded( TestCaseStats( testInfo, + deltaTotals, + redirectedCout, + redirectedCerr, + aborting() ) ); + + m_activeTestCase = CATCH_NULL; + m_testCaseTracker = CATCH_NULL; + + return deltaTotals; + } + + Ptr config() const { + return m_config; + } + + private: // IResultCapture + + virtual void assertionEnded( AssertionResult const& result ) { + if( result.getResultType() == ResultWas::Ok ) { + m_totals.assertions.passed++; + } + else if( !result.isOk() ) { + m_totals.assertions.failed++; + } + + // We have no use for the return value (whether messages should be cleared), because messages were made scoped + // and should be let to clear themselves out. + static_cast(m_reporter->assertionEnded(AssertionStats(result, m_messages, m_totals))); + + // Reset working state + m_lastAssertionInfo = AssertionInfo( "", m_lastAssertionInfo.lineInfo, "{Unknown expression after the reported line}" , m_lastAssertionInfo.resultDisposition ); + m_lastResult = result; + } + + virtual bool sectionStarted ( + SectionInfo const& sectionInfo, + Counts& assertions + ) + { + ITracker& sectionTracker = SectionTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( sectionInfo.name, sectionInfo.lineInfo ) ); + if( !sectionTracker.isOpen() ) + return false; + m_activeSections.push_back( §ionTracker ); + + m_lastAssertionInfo.lineInfo = sectionInfo.lineInfo; + + m_reporter->sectionStarting( sectionInfo ); + + assertions = m_totals.assertions; + + return true; + } + bool testForMissingAssertions( Counts& assertions ) { + if( assertions.total() != 0 ) + return false; + if( !m_config->warnAboutMissingAssertions() ) + return false; + if( m_trackerContext.currentTracker().hasChildren() ) + return false; + m_totals.assertions.failed++; + assertions.failed++; + return true; + } + + virtual void sectionEnded( SectionEndInfo const& endInfo ) { + Counts assertions = m_totals.assertions - endInfo.prevAssertions; + bool missingAssertions = testForMissingAssertions( assertions ); + + if( !m_activeSections.empty() ) { + m_activeSections.back()->close(); + m_activeSections.pop_back(); + } + + m_reporter->sectionEnded( SectionStats( endInfo.sectionInfo, assertions, endInfo.durationInSeconds, missingAssertions ) ); + m_messages.clear(); + } + + virtual void sectionEndedEarly( SectionEndInfo const& endInfo ) { + if( m_unfinishedSections.empty() ) + m_activeSections.back()->fail(); + else + m_activeSections.back()->close(); + m_activeSections.pop_back(); + + m_unfinishedSections.push_back( endInfo ); + } + + virtual void pushScopedMessage( MessageInfo const& message ) { + m_messages.push_back( message ); + } + + virtual void popScopedMessage( MessageInfo const& message ) { + m_messages.erase( std::remove( m_messages.begin(), m_messages.end(), message ), m_messages.end() ); + } + + virtual std::string getCurrentTestName() const { + return m_activeTestCase + ? m_activeTestCase->getTestCaseInfo().name + : std::string(); + } + + virtual const AssertionResult* getLastResult() const { + return &m_lastResult; + } + + virtual void exceptionEarlyReported() { + m_shouldReportUnexpected = false; + } + + virtual void handleFatalErrorCondition( std::string const& message ) { + // Don't rebuild the result -- the stringification itself can cause more fatal errors + // Instead, fake a result data. + AssertionResultData tempResult; + tempResult.resultType = ResultWas::FatalErrorCondition; + tempResult.message = message; + AssertionResult result(m_lastAssertionInfo, tempResult); + + getResultCapture().assertionEnded(result); + + handleUnfinishedSections(); + + // Recreate section for test case (as we will lose the one that was in scope) + TestCaseInfo const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection( testCaseInfo.lineInfo, testCaseInfo.name, testCaseInfo.description ); + + Counts assertions; + assertions.failed = 1; + SectionStats testCaseSectionStats( testCaseSection, assertions, 0, false ); + m_reporter->sectionEnded( testCaseSectionStats ); + + TestCaseInfo testInfo = m_activeTestCase->getTestCaseInfo(); + + Totals deltaTotals; + deltaTotals.testCases.failed = 1; + m_reporter->testCaseEnded( TestCaseStats( testInfo, + deltaTotals, + std::string(), + std::string(), + false ) ); + m_totals.testCases.failed++; + testGroupEnded( std::string(), m_totals, 1, 1 ); + m_reporter->testRunEnded( TestRunStats( m_runInfo, m_totals, false ) ); + } + + public: + // !TBD We need to do this another way! + bool aborting() const { + return m_totals.assertions.failed == static_cast( m_config->abortAfter() ); + } + + private: + + void runCurrentTest( std::string& redirectedCout, std::string& redirectedCerr ) { + TestCaseInfo const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection( testCaseInfo.lineInfo, testCaseInfo.name, testCaseInfo.description ); + m_reporter->sectionStarting( testCaseSection ); + Counts prevAssertions = m_totals.assertions; + double duration = 0; + m_shouldReportUnexpected = true; + try { + m_lastAssertionInfo = AssertionInfo( "TEST_CASE", testCaseInfo.lineInfo, "", ResultDisposition::Normal ); + + seedRng( *m_config ); + + Timer timer; + timer.start(); + if( m_reporter->getPreferences().shouldRedirectStdOut ) { + StreamRedirect coutRedir( Catch::cout(), redirectedCout ); + StreamRedirect cerrRedir( Catch::cerr(), redirectedCerr ); + invokeActiveTestCase(); + } + else { + invokeActiveTestCase(); + } + duration = timer.getElapsedSeconds(); + } + catch( TestFailureException& ) { + // This just means the test was aborted due to failure + } + catch(...) { + // Under CATCH_CONFIG_FAST_COMPILE, unexpected exceptions under REQUIRE assertions + // are reported without translation at the point of origin. + if (m_shouldReportUnexpected) { + makeUnexpectedResultBuilder().useActiveException(); + } + } + m_testCaseTracker->close(); + handleUnfinishedSections(); + m_messages.clear(); + + Counts assertions = m_totals.assertions - prevAssertions; + bool missingAssertions = testForMissingAssertions( assertions ); + + if( testCaseInfo.okToFail() ) { + std::swap( assertions.failedButOk, assertions.failed ); + m_totals.assertions.failed -= assertions.failedButOk; + m_totals.assertions.failedButOk += assertions.failedButOk; + } + + SectionStats testCaseSectionStats( testCaseSection, assertions, duration, missingAssertions ); + m_reporter->sectionEnded( testCaseSectionStats ); + } + + void invokeActiveTestCase() { + FatalConditionHandler fatalConditionHandler; // Handle signals + m_activeTestCase->invoke(); + fatalConditionHandler.reset(); + } + + private: + + ResultBuilder makeUnexpectedResultBuilder() const { + return ResultBuilder( m_lastAssertionInfo.macroName, + m_lastAssertionInfo.lineInfo, + m_lastAssertionInfo.capturedExpression, + m_lastAssertionInfo.resultDisposition ); + } + + void handleUnfinishedSections() { + // If sections ended prematurely due to an exception we stored their + // infos here so we can tear them down outside the unwind process. + for( std::vector::const_reverse_iterator it = m_unfinishedSections.rbegin(), + itEnd = m_unfinishedSections.rend(); + it != itEnd; + ++it ) + sectionEnded( *it ); + m_unfinishedSections.clear(); + } + + TestRunInfo m_runInfo; + IMutableContext& m_context; + TestCase const* m_activeTestCase; + ITracker* m_testCaseTracker; + ITracker* m_currentSectionTracker; + AssertionResult m_lastResult; + + Ptr m_config; + Totals m_totals; + Ptr m_reporter; + std::vector m_messages; + AssertionInfo m_lastAssertionInfo; + std::vector m_unfinishedSections; + std::vector m_activeSections; + TrackerContext m_trackerContext; + bool m_shouldReportUnexpected; + }; + + IResultCapture& getResultCapture() { + if( IResultCapture* capture = getCurrentContext().getResultCapture() ) + return *capture; + else + throw std::logic_error( "No result capture instance" ); + } + +} // end namespace Catch + +// #included from: internal/catch_version.h +#define TWOBLUECUBES_CATCH_VERSION_H_INCLUDED + +namespace Catch { + + // Versioning information + struct Version { + Version( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _patchNumber, + char const * const _branchName, + unsigned int _buildNumber ); + + unsigned int const majorVersion; + unsigned int const minorVersion; + unsigned int const patchNumber; + + // buildNumber is only used if branchName is not null + char const * const branchName; + unsigned int const buildNumber; + + friend std::ostream& operator << ( std::ostream& os, Version const& version ); + + private: + void operator=( Version const& ); + }; + + inline Version libraryVersion(); +} + +#include +#include +#include + +namespace Catch { + + Ptr createReporter( std::string const& reporterName, Ptr const& config ) { + Ptr reporter = getRegistryHub().getReporterRegistry().create( reporterName, config.get() ); + if( !reporter ) { + std::ostringstream oss; + oss << "No reporter registered with name: '" << reporterName << "'"; + throw std::domain_error( oss.str() ); + } + return reporter; + } + + Ptr makeReporter( Ptr const& config ) { + std::vector reporters = config->getReporterNames(); + if( reporters.empty() ) + reporters.push_back( "console" ); + + Ptr reporter; + for( std::vector::const_iterator it = reporters.begin(), itEnd = reporters.end(); + it != itEnd; + ++it ) + reporter = addReporter( reporter, createReporter( *it, config ) ); + return reporter; + } + Ptr addListeners( Ptr const& config, Ptr reporters ) { + IReporterRegistry::Listeners listeners = getRegistryHub().getReporterRegistry().getListeners(); + for( IReporterRegistry::Listeners::const_iterator it = listeners.begin(), itEnd = listeners.end(); + it != itEnd; + ++it ) + reporters = addReporter(reporters, (*it)->create( ReporterConfig( config ) ) ); + return reporters; + } + + Totals runTests( Ptr const& config ) { + + Ptr iconfig = config.get(); + + Ptr reporter = makeReporter( config ); + reporter = addListeners( iconfig, reporter ); + + RunContext context( iconfig, reporter ); + + Totals totals; + + context.testGroupStarting( config->name(), 1, 1 ); + + TestSpec testSpec = config->testSpec(); + if( !testSpec.hasFilters() ) + testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "~[.]" ).testSpec(); // All not hidden tests + + std::vector const& allTestCases = getAllTestCasesSorted( *iconfig ); + for( std::vector::const_iterator it = allTestCases.begin(), itEnd = allTestCases.end(); + it != itEnd; + ++it ) { + if( !context.aborting() && matchTest( *it, testSpec, *iconfig ) ) + totals += context.runTest( *it ); + else + reporter->skipTest( *it ); + } + + context.testGroupEnded( iconfig->name(), totals, 1, 1 ); + return totals; + } + + void applyFilenamesAsTags( IConfig const& config ) { + std::vector const& tests = getAllTestCasesSorted( config ); + for(std::size_t i = 0; i < tests.size(); ++i ) { + TestCase& test = const_cast( tests[i] ); + std::set tags = test.tags; + + std::string filename = test.lineInfo.file; + std::string::size_type lastSlash = filename.find_last_of( "\\/" ); + if( lastSlash != std::string::npos ) + filename = filename.substr( lastSlash+1 ); + + std::string::size_type lastDot = filename.find_last_of( "." ); + if( lastDot != std::string::npos ) + filename = filename.substr( 0, lastDot ); + + tags.insert( "#" + filename ); + setTags( test, tags ); + } + } + + class Session : NonCopyable { + static bool alreadyInstantiated; + + public: + + struct OnUnusedOptions { enum DoWhat { Ignore, Fail }; }; + + Session() + : m_cli( makeCommandLineParser() ) { + if( alreadyInstantiated ) { + std::string msg = "Only one instance of Catch::Session can ever be used"; + Catch::cerr() << msg << std::endl; + throw std::logic_error( msg ); + } + alreadyInstantiated = true; + } + ~Session() { + Catch::cleanUp(); + } + + void showHelp( std::string const& processName ) { + Catch::cout() << "\nCatch v" << libraryVersion() << "\n"; + + m_cli.usage( Catch::cout(), processName ); + Catch::cout() << "For more detail usage please see the project docs\n" << std::endl; + } + + int applyCommandLine( int argc, char const* const* const argv, OnUnusedOptions::DoWhat unusedOptionBehaviour = OnUnusedOptions::Fail ) { + try { + m_cli.setThrowOnUnrecognisedTokens( unusedOptionBehaviour == OnUnusedOptions::Fail ); + m_unusedTokens = m_cli.parseInto( Clara::argsToVector( argc, argv ), m_configData ); + if( m_configData.showHelp ) + showHelp( m_configData.processName ); + m_config.reset(); + } + catch( std::exception& ex ) { + { + Colour colourGuard( Colour::Red ); + Catch::cerr() + << "\nError(s) in input:\n" + << Text( ex.what(), TextAttributes().setIndent(2) ) + << "\n\n"; + } + m_cli.usage( Catch::cout(), m_configData.processName ); + return (std::numeric_limits::max)(); + } + return 0; + } + + void useConfigData( ConfigData const& _configData ) { + m_configData = _configData; + m_config.reset(); + } + + int run( int argc, char const* const* const argv ) { + + int returnCode = applyCommandLine( argc, argv ); + if( returnCode == 0 ) + returnCode = run(); + return returnCode; + } + + #if defined(WIN32) && defined(UNICODE) + int run( int argc, wchar_t const* const* const argv ) { + + char **utf8Argv = new char *[ argc ]; + + for ( int i = 0; i < argc; ++i ) { + int bufSize = WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, NULL, 0, NULL, NULL ); + + utf8Argv[ i ] = new char[ bufSize ]; + + WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, utf8Argv[i], bufSize, NULL, NULL ); + } + + int returnCode = applyCommandLine( argc, utf8Argv ); + if( returnCode == 0 ) + returnCode = run(); + + for ( int i = 0; i < argc; ++i ) + delete [] utf8Argv[ i ]; + + delete [] utf8Argv; + + return returnCode; + } + #endif + + int run() { + if( m_configData.showHelp ) + return 0; + + try + { + config(); // Force config to be constructed + + seedRng( *m_config ); + + if( m_configData.filenamesAsTags ) + applyFilenamesAsTags( *m_config ); + + // Handle list request + if( Option listed = list( config() ) ) + return static_cast( *listed ); + + return static_cast( runTests( m_config ).assertions.failed ); + } + catch( std::exception& ex ) { + Catch::cerr() << ex.what() << std::endl; + return (std::numeric_limits::max)(); + } + } + + Clara::CommandLine const& cli() const { + return m_cli; + } + std::vector const& unusedTokens() const { + return m_unusedTokens; + } + ConfigData& configData() { + return m_configData; + } + Config& config() { + if( !m_config ) + m_config = new Config( m_configData ); + return *m_config; + } + private: + Clara::CommandLine m_cli; + std::vector m_unusedTokens; + ConfigData m_configData; + Ptr m_config; + }; + + bool Session::alreadyInstantiated = false; + +} // end namespace Catch + +// #included from: catch_registry_hub.hpp +#define TWOBLUECUBES_CATCH_REGISTRY_HUB_HPP_INCLUDED + +// #included from: catch_test_case_registry_impl.hpp +#define TWOBLUECUBES_CATCH_TEST_CASE_REGISTRY_IMPL_HPP_INCLUDED + +#include +#include +#include +#include + +namespace Catch { + + struct RandomNumberGenerator { + typedef std::ptrdiff_t result_type; + + result_type operator()( result_type n ) const { return std::rand() % n; } + +#ifdef CATCH_CONFIG_CPP11_SHUFFLE + static constexpr result_type min() { return 0; } + static constexpr result_type max() { return 1000000; } + result_type operator()() const { return std::rand() % max(); } +#endif + template + static void shuffle( V& vector ) { + RandomNumberGenerator rng; +#ifdef CATCH_CONFIG_CPP11_SHUFFLE + std::shuffle( vector.begin(), vector.end(), rng ); +#else + std::random_shuffle( vector.begin(), vector.end(), rng ); +#endif + } + }; + + inline std::vector sortTests( IConfig const& config, std::vector const& unsortedTestCases ) { + + std::vector sorted = unsortedTestCases; + + switch( config.runOrder() ) { + case RunTests::InLexicographicalOrder: + std::sort( sorted.begin(), sorted.end() ); + break; + case RunTests::InRandomOrder: + { + seedRng( config ); + RandomNumberGenerator::shuffle( sorted ); + } + break; + case RunTests::InDeclarationOrder: + // already in declaration order + break; + } + return sorted; + } + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ) { + return testSpec.matches( testCase ) && ( config.allowThrows() || !testCase.throws() ); + } + + void enforceNoDuplicateTestCases( std::vector const& functions ) { + std::set seenFunctions; + for( std::vector::const_iterator it = functions.begin(), itEnd = functions.end(); + it != itEnd; + ++it ) { + std::pair::const_iterator, bool> prev = seenFunctions.insert( *it ); + if( !prev.second ) { + std::ostringstream ss; + + ss << Colour( Colour::Red ) + << "error: TEST_CASE( \"" << it->name << "\" ) already defined.\n" + << "\tFirst seen at " << prev.first->getTestCaseInfo().lineInfo << '\n' + << "\tRedefined at " << it->getTestCaseInfo().lineInfo << std::endl; + + throw std::runtime_error(ss.str()); + } + } + } + + std::vector filterTests( std::vector const& testCases, TestSpec const& testSpec, IConfig const& config ) { + std::vector filtered; + filtered.reserve( testCases.size() ); + for( std::vector::const_iterator it = testCases.begin(), itEnd = testCases.end(); + it != itEnd; + ++it ) + if( matchTest( *it, testSpec, config ) ) + filtered.push_back( *it ); + return filtered; + } + std::vector const& getAllTestCasesSorted( IConfig const& config ) { + return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config ); + } + + class TestRegistry : public ITestCaseRegistry { + public: + TestRegistry() + : m_currentSortOrder( RunTests::InDeclarationOrder ), + m_unnamedCount( 0 ) + {} + virtual ~TestRegistry(); + + virtual void registerTest( TestCase const& testCase ) { + std::string name = testCase.getTestCaseInfo().name; + if( name.empty() ) { + std::ostringstream oss; + oss << "Anonymous test case " << ++m_unnamedCount; + return registerTest( testCase.withName( oss.str() ) ); + } + m_functions.push_back( testCase ); + } + + virtual std::vector const& getAllTests() const { + return m_functions; + } + virtual std::vector const& getAllTestsSorted( IConfig const& config ) const { + if( m_sortedFunctions.empty() ) + enforceNoDuplicateTestCases( m_functions ); + + if( m_currentSortOrder != config.runOrder() || m_sortedFunctions.empty() ) { + m_sortedFunctions = sortTests( config, m_functions ); + m_currentSortOrder = config.runOrder(); + } + return m_sortedFunctions; + } + + private: + std::vector m_functions; + mutable RunTests::InWhatOrder m_currentSortOrder; + mutable std::vector m_sortedFunctions; + size_t m_unnamedCount; + std::ios_base::Init m_ostreamInit; // Forces cout/ cerr to be initialised + }; + + /////////////////////////////////////////////////////////////////////////// + + class FreeFunctionTestCase : public SharedImpl { + public: + + FreeFunctionTestCase( TestFunction fun ) : m_fun( fun ) {} + + virtual void invoke() const { + m_fun(); + } + + private: + virtual ~FreeFunctionTestCase(); + + TestFunction m_fun; + }; + + inline std::string extractClassName( std::string const& classOrQualifiedMethodName ) { + std::string className = classOrQualifiedMethodName; + if( startsWith( className, '&' ) ) + { + std::size_t lastColons = className.rfind( "::" ); + std::size_t penultimateColons = className.rfind( "::", lastColons-1 ); + if( penultimateColons == std::string::npos ) + penultimateColons = 1; + className = className.substr( penultimateColons, lastColons-penultimateColons ); + } + return className; + } + + void registerTestCase + ( ITestCase* testCase, + char const* classOrQualifiedMethodName, + NameAndDesc const& nameAndDesc, + SourceLineInfo const& lineInfo ) { + + getMutableRegistryHub().registerTest + ( makeTestCase + ( testCase, + extractClassName( classOrQualifiedMethodName ), + nameAndDesc.name, + nameAndDesc.description, + lineInfo ) ); + } + void registerTestCaseFunction + ( TestFunction function, + SourceLineInfo const& lineInfo, + NameAndDesc const& nameAndDesc ) { + registerTestCase( new FreeFunctionTestCase( function ), "", nameAndDesc, lineInfo ); + } + + /////////////////////////////////////////////////////////////////////////// + + AutoReg::AutoReg + ( TestFunction function, + SourceLineInfo const& lineInfo, + NameAndDesc const& nameAndDesc ) { + registerTestCaseFunction( function, lineInfo, nameAndDesc ); + } + + AutoReg::~AutoReg() {} + +} // end namespace Catch + +// #included from: catch_reporter_registry.hpp +#define TWOBLUECUBES_CATCH_REPORTER_REGISTRY_HPP_INCLUDED + +#include + +namespace Catch { + + class ReporterRegistry : public IReporterRegistry { + + public: + + virtual ~ReporterRegistry() CATCH_OVERRIDE {} + + virtual IStreamingReporter* create( std::string const& name, Ptr const& config ) const CATCH_OVERRIDE { + FactoryMap::const_iterator it = m_factories.find( name ); + if( it == m_factories.end() ) + return CATCH_NULL; + return it->second->create( ReporterConfig( config ) ); + } + + void registerReporter( std::string const& name, Ptr const& factory ) { + m_factories.insert( std::make_pair( name, factory ) ); + } + void registerListener( Ptr const& factory ) { + m_listeners.push_back( factory ); + } + + virtual FactoryMap const& getFactories() const CATCH_OVERRIDE { + return m_factories; + } + virtual Listeners const& getListeners() const CATCH_OVERRIDE { + return m_listeners; + } + + private: + FactoryMap m_factories; + Listeners m_listeners; + }; +} + +// #included from: catch_exception_translator_registry.hpp +#define TWOBLUECUBES_CATCH_EXCEPTION_TRANSLATOR_REGISTRY_HPP_INCLUDED + +#ifdef __OBJC__ +#import "Foundation/Foundation.h" +#endif + +namespace Catch { + + class ExceptionTranslatorRegistry : public IExceptionTranslatorRegistry { + public: + ~ExceptionTranslatorRegistry() { + deleteAll( m_translators ); + } + + virtual void registerTranslator( const IExceptionTranslator* translator ) { + m_translators.push_back( translator ); + } + + virtual std::string translateActiveException() const { + try { +#ifdef __OBJC__ + // In Objective-C try objective-c exceptions first + @try { + return tryTranslators(); + } + @catch (NSException *exception) { + return Catch::toString( [exception description] ); + } +#else + return tryTranslators(); +#endif + } + catch( TestFailureException& ) { + throw; + } + catch( std::exception& ex ) { + return ex.what(); + } + catch( std::string& msg ) { + return msg; + } + catch( const char* msg ) { + return msg; + } + catch(...) { + return "Unknown exception"; + } + } + + std::string tryTranslators() const { + if( m_translators.empty() ) + throw; + else + return m_translators[0]->translate( m_translators.begin()+1, m_translators.end() ); + } + + private: + std::vector m_translators; + }; +} + +// #included from: catch_tag_alias_registry.h +#define TWOBLUECUBES_CATCH_TAG_ALIAS_REGISTRY_H_INCLUDED + +#include + +namespace Catch { + + class TagAliasRegistry : public ITagAliasRegistry { + public: + virtual ~TagAliasRegistry(); + virtual Option find( std::string const& alias ) const; + virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const; + void add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ); + + private: + std::map m_registry; + }; + +} // end namespace Catch + +namespace Catch { + + namespace { + + class RegistryHub : public IRegistryHub, public IMutableRegistryHub { + + RegistryHub( RegistryHub const& ); + void operator=( RegistryHub const& ); + + public: // IRegistryHub + RegistryHub() { + } + virtual IReporterRegistry const& getReporterRegistry() const CATCH_OVERRIDE { + return m_reporterRegistry; + } + virtual ITestCaseRegistry const& getTestCaseRegistry() const CATCH_OVERRIDE { + return m_testCaseRegistry; + } + virtual IExceptionTranslatorRegistry& getExceptionTranslatorRegistry() CATCH_OVERRIDE { + return m_exceptionTranslatorRegistry; + } + virtual ITagAliasRegistry const& getTagAliasRegistry() const CATCH_OVERRIDE { + return m_tagAliasRegistry; + } + + public: // IMutableRegistryHub + virtual void registerReporter( std::string const& name, Ptr const& factory ) CATCH_OVERRIDE { + m_reporterRegistry.registerReporter( name, factory ); + } + virtual void registerListener( Ptr const& factory ) CATCH_OVERRIDE { + m_reporterRegistry.registerListener( factory ); + } + virtual void registerTest( TestCase const& testInfo ) CATCH_OVERRIDE { + m_testCaseRegistry.registerTest( testInfo ); + } + virtual void registerTranslator( const IExceptionTranslator* translator ) CATCH_OVERRIDE { + m_exceptionTranslatorRegistry.registerTranslator( translator ); + } + virtual void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) CATCH_OVERRIDE { + m_tagAliasRegistry.add( alias, tag, lineInfo ); + } + + private: + TestRegistry m_testCaseRegistry; + ReporterRegistry m_reporterRegistry; + ExceptionTranslatorRegistry m_exceptionTranslatorRegistry; + TagAliasRegistry m_tagAliasRegistry; + }; + + // Single, global, instance + inline RegistryHub*& getTheRegistryHub() { + static RegistryHub* theRegistryHub = CATCH_NULL; + if( !theRegistryHub ) + theRegistryHub = new RegistryHub(); + return theRegistryHub; + } + } + + IRegistryHub& getRegistryHub() { + return *getTheRegistryHub(); + } + IMutableRegistryHub& getMutableRegistryHub() { + return *getTheRegistryHub(); + } + void cleanUp() { + delete getTheRegistryHub(); + getTheRegistryHub() = CATCH_NULL; + cleanUpContext(); + } + std::string translateActiveException() { + return getRegistryHub().getExceptionTranslatorRegistry().translateActiveException(); + } + +} // end namespace Catch + +// #included from: catch_notimplemented_exception.hpp +#define TWOBLUECUBES_CATCH_NOTIMPLEMENTED_EXCEPTION_HPP_INCLUDED + +#include + +namespace Catch { + + NotImplementedException::NotImplementedException( SourceLineInfo const& lineInfo ) + : m_lineInfo( lineInfo ) { + std::ostringstream oss; + oss << lineInfo << ": function "; + oss << "not implemented"; + m_what = oss.str(); + } + + const char* NotImplementedException::what() const CATCH_NOEXCEPT { + return m_what.c_str(); + } + +} // end namespace Catch + +// #included from: catch_context_impl.hpp +#define TWOBLUECUBES_CATCH_CONTEXT_IMPL_HPP_INCLUDED + +// #included from: catch_stream.hpp +#define TWOBLUECUBES_CATCH_STREAM_HPP_INCLUDED + +#include +#include +#include + +namespace Catch { + + template + class StreamBufImpl : public StreamBufBase { + char data[bufferSize]; + WriterF m_writer; + + public: + StreamBufImpl() { + setp( data, data + sizeof(data) ); + } + + ~StreamBufImpl() CATCH_NOEXCEPT { + sync(); + } + + private: + int overflow( int c ) { + sync(); + + if( c != EOF ) { + if( pbase() == epptr() ) + m_writer( std::string( 1, static_cast( c ) ) ); + else + sputc( static_cast( c ) ); + } + return 0; + } + + int sync() { + if( pbase() != pptr() ) { + m_writer( std::string( pbase(), static_cast( pptr() - pbase() ) ) ); + setp( pbase(), epptr() ); + } + return 0; + } + }; + + /////////////////////////////////////////////////////////////////////////// + + FileStream::FileStream( std::string const& filename ) { + m_ofs.open( filename.c_str() ); + if( m_ofs.fail() ) { + std::ostringstream oss; + oss << "Unable to open file: '" << filename << '\''; + throw std::domain_error( oss.str() ); + } + } + + std::ostream& FileStream::stream() const { + return m_ofs; + } + + struct OutputDebugWriter { + + void operator()( std::string const&str ) { + writeToDebugConsole( str ); + } + }; + + DebugOutStream::DebugOutStream() + : m_streamBuf( new StreamBufImpl() ), + m_os( m_streamBuf.get() ) + {} + + std::ostream& DebugOutStream::stream() const { + return m_os; + } + + // Store the streambuf from cout up-front because + // cout may get redirected when running tests + CoutStream::CoutStream() + : m_os( Catch::cout().rdbuf() ) + {} + + std::ostream& CoutStream::stream() const { + return m_os; + } + +#ifndef CATCH_CONFIG_NOSTDOUT // If you #define this you must implement these functions + std::ostream& cout() { + return std::cout; + } + std::ostream& cerr() { + return std::cerr; + } +#endif +} + +namespace Catch { + + class Context : public IMutableContext { + + Context() : m_config( CATCH_NULL ), m_runner( CATCH_NULL ), m_resultCapture( CATCH_NULL ) {} + Context( Context const& ); + void operator=( Context const& ); + + public: + virtual ~Context() { + deleteAllValues( m_generatorsByTestName ); + } + + public: // IContext + virtual IResultCapture* getResultCapture() { + return m_resultCapture; + } + virtual IRunner* getRunner() { + return m_runner; + } + virtual size_t getGeneratorIndex( std::string const& fileInfo, size_t totalSize ) { + return getGeneratorsForCurrentTest() + .getGeneratorInfo( fileInfo, totalSize ) + .getCurrentIndex(); + } + virtual bool advanceGeneratorsForCurrentTest() { + IGeneratorsForTest* generators = findGeneratorsForCurrentTest(); + return generators && generators->moveNext(); + } + + virtual Ptr getConfig() const { + return m_config; + } + + public: // IMutableContext + virtual void setResultCapture( IResultCapture* resultCapture ) { + m_resultCapture = resultCapture; + } + virtual void setRunner( IRunner* runner ) { + m_runner = runner; + } + virtual void setConfig( Ptr const& config ) { + m_config = config; + } + + friend IMutableContext& getCurrentMutableContext(); + + private: + IGeneratorsForTest* findGeneratorsForCurrentTest() { + std::string testName = getResultCapture()->getCurrentTestName(); + + std::map::const_iterator it = + m_generatorsByTestName.find( testName ); + return it != m_generatorsByTestName.end() + ? it->second + : CATCH_NULL; + } + + IGeneratorsForTest& getGeneratorsForCurrentTest() { + IGeneratorsForTest* generators = findGeneratorsForCurrentTest(); + if( !generators ) { + std::string testName = getResultCapture()->getCurrentTestName(); + generators = createGeneratorsForTest(); + m_generatorsByTestName.insert( std::make_pair( testName, generators ) ); + } + return *generators; + } + + private: + Ptr m_config; + IRunner* m_runner; + IResultCapture* m_resultCapture; + std::map m_generatorsByTestName; + }; + + namespace { + Context* currentContext = CATCH_NULL; + } + IMutableContext& getCurrentMutableContext() { + if( !currentContext ) + currentContext = new Context(); + return *currentContext; + } + IContext& getCurrentContext() { + return getCurrentMutableContext(); + } + + void cleanUpContext() { + delete currentContext; + currentContext = CATCH_NULL; + } +} + +// #included from: catch_console_colour_impl.hpp +#define TWOBLUECUBES_CATCH_CONSOLE_COLOUR_IMPL_HPP_INCLUDED + +// #included from: catch_errno_guard.hpp +#define TWOBLUECUBES_CATCH_ERRNO_GUARD_HPP_INCLUDED + +#include + +namespace Catch { + + class ErrnoGuard { + public: + ErrnoGuard():m_oldErrno(errno){} + ~ErrnoGuard() { errno = m_oldErrno; } + private: + int m_oldErrno; + }; + +} + +namespace Catch { + namespace { + + struct IColourImpl { + virtual ~IColourImpl() {} + virtual void use( Colour::Code _colourCode ) = 0; + }; + + struct NoColourImpl : IColourImpl { + void use( Colour::Code ) {} + + static IColourImpl* instance() { + static NoColourImpl s_instance; + return &s_instance; + } + }; + + } // anon namespace +} // namespace Catch + +#if !defined( CATCH_CONFIG_COLOUR_NONE ) && !defined( CATCH_CONFIG_COLOUR_WINDOWS ) && !defined( CATCH_CONFIG_COLOUR_ANSI ) +# ifdef CATCH_PLATFORM_WINDOWS +# define CATCH_CONFIG_COLOUR_WINDOWS +# else +# define CATCH_CONFIG_COLOUR_ANSI +# endif +#endif + +#if defined ( CATCH_CONFIG_COLOUR_WINDOWS ) ///////////////////////////////////////// + +namespace Catch { +namespace { + + class Win32ColourImpl : public IColourImpl { + public: + Win32ColourImpl() : stdoutHandle( GetStdHandle(STD_OUTPUT_HANDLE) ) + { + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo( stdoutHandle, &csbiInfo ); + originalForegroundAttributes = csbiInfo.wAttributes & ~( BACKGROUND_GREEN | BACKGROUND_RED | BACKGROUND_BLUE | BACKGROUND_INTENSITY ); + originalBackgroundAttributes = csbiInfo.wAttributes & ~( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE | FOREGROUND_INTENSITY ); + } + + virtual void use( Colour::Code _colourCode ) { + switch( _colourCode ) { + case Colour::None: return setTextAttribute( originalForegroundAttributes ); + case Colour::White: return setTextAttribute( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + case Colour::Red: return setTextAttribute( FOREGROUND_RED ); + case Colour::Green: return setTextAttribute( FOREGROUND_GREEN ); + case Colour::Blue: return setTextAttribute( FOREGROUND_BLUE ); + case Colour::Cyan: return setTextAttribute( FOREGROUND_BLUE | FOREGROUND_GREEN ); + case Colour::Yellow: return setTextAttribute( FOREGROUND_RED | FOREGROUND_GREEN ); + case Colour::Grey: return setTextAttribute( 0 ); + + case Colour::LightGrey: return setTextAttribute( FOREGROUND_INTENSITY ); + case Colour::BrightRed: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED ); + case Colour::BrightGreen: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN ); + case Colour::BrightWhite: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + + case Colour::Bright: throw std::logic_error( "not a colour" ); + } + } + + private: + void setTextAttribute( WORD _textAttribute ) { + SetConsoleTextAttribute( stdoutHandle, _textAttribute | originalBackgroundAttributes ); + } + HANDLE stdoutHandle; + WORD originalForegroundAttributes; + WORD originalBackgroundAttributes; + }; + + IColourImpl* platformColourInstance() { + static Win32ColourImpl s_instance; + + Ptr config = getCurrentContext().getConfig(); + UseColour::YesOrNo colourMode = config + ? config->useColour() + : UseColour::Auto; + if( colourMode == UseColour::Auto ) + colourMode = !isDebuggerActive() + ? UseColour::Yes + : UseColour::No; + return colourMode == UseColour::Yes + ? &s_instance + : NoColourImpl::instance(); + } + +} // end anon namespace +} // end namespace Catch + +#elif defined( CATCH_CONFIG_COLOUR_ANSI ) ////////////////////////////////////// + +#include + +namespace Catch { +namespace { + + // use POSIX/ ANSI console terminal codes + // Thanks to Adam Strzelecki for original contribution + // (http://github.com/nanoant) + // https://github.com/philsquared/Catch/pull/131 + class PosixColourImpl : public IColourImpl { + public: + virtual void use( Colour::Code _colourCode ) { + switch( _colourCode ) { + case Colour::None: + case Colour::White: return setColour( "[0m" ); + case Colour::Red: return setColour( "[0;31m" ); + case Colour::Green: return setColour( "[0;32m" ); + case Colour::Blue: return setColour( "[0;34m" ); + case Colour::Cyan: return setColour( "[0;36m" ); + case Colour::Yellow: return setColour( "[0;33m" ); + case Colour::Grey: return setColour( "[1;30m" ); + + case Colour::LightGrey: return setColour( "[0;37m" ); + case Colour::BrightRed: return setColour( "[1;31m" ); + case Colour::BrightGreen: return setColour( "[1;32m" ); + case Colour::BrightWhite: return setColour( "[1;37m" ); + + case Colour::Bright: throw std::logic_error( "not a colour" ); + } + } + static IColourImpl* instance() { + static PosixColourImpl s_instance; + return &s_instance; + } + + private: + void setColour( const char* _escapeCode ) { + Catch::cout() << '\033' << _escapeCode; + } + }; + + IColourImpl* platformColourInstance() { + ErrnoGuard guard; + Ptr config = getCurrentContext().getConfig(); + UseColour::YesOrNo colourMode = config + ? config->useColour() + : UseColour::Auto; + if( colourMode == UseColour::Auto ) + colourMode = (!isDebuggerActive() && isatty(STDOUT_FILENO) ) + ? UseColour::Yes + : UseColour::No; + return colourMode == UseColour::Yes + ? PosixColourImpl::instance() + : NoColourImpl::instance(); + } + +} // end anon namespace +} // end namespace Catch + +#else // not Windows or ANSI /////////////////////////////////////////////// + +namespace Catch { + + static IColourImpl* platformColourInstance() { return NoColourImpl::instance(); } + +} // end namespace Catch + +#endif // Windows/ ANSI/ None + +namespace Catch { + + Colour::Colour( Code _colourCode ) : m_moved( false ) { use( _colourCode ); } + Colour::Colour( Colour const& _other ) : m_moved( false ) { const_cast( _other ).m_moved = true; } + Colour::~Colour(){ if( !m_moved ) use( None ); } + + void Colour::use( Code _colourCode ) { + static IColourImpl* impl = platformColourInstance(); + impl->use( _colourCode ); + } + +} // end namespace Catch + +// #included from: catch_generators_impl.hpp +#define TWOBLUECUBES_CATCH_GENERATORS_IMPL_HPP_INCLUDED + +#include +#include +#include + +namespace Catch { + + struct GeneratorInfo : IGeneratorInfo { + + GeneratorInfo( std::size_t size ) + : m_size( size ), + m_currentIndex( 0 ) + {} + + bool moveNext() { + if( ++m_currentIndex == m_size ) { + m_currentIndex = 0; + return false; + } + return true; + } + + std::size_t getCurrentIndex() const { + return m_currentIndex; + } + + std::size_t m_size; + std::size_t m_currentIndex; + }; + + /////////////////////////////////////////////////////////////////////////// + + class GeneratorsForTest : public IGeneratorsForTest { + + public: + ~GeneratorsForTest() { + deleteAll( m_generatorsInOrder ); + } + + IGeneratorInfo& getGeneratorInfo( std::string const& fileInfo, std::size_t size ) { + std::map::const_iterator it = m_generatorsByName.find( fileInfo ); + if( it == m_generatorsByName.end() ) { + IGeneratorInfo* info = new GeneratorInfo( size ); + m_generatorsByName.insert( std::make_pair( fileInfo, info ) ); + m_generatorsInOrder.push_back( info ); + return *info; + } + return *it->second; + } + + bool moveNext() { + std::vector::const_iterator it = m_generatorsInOrder.begin(); + std::vector::const_iterator itEnd = m_generatorsInOrder.end(); + for(; it != itEnd; ++it ) { + if( (*it)->moveNext() ) + return true; + } + return false; + } + + private: + std::map m_generatorsByName; + std::vector m_generatorsInOrder; + }; + + IGeneratorsForTest* createGeneratorsForTest() + { + return new GeneratorsForTest(); + } + +} // end namespace Catch + +// #included from: catch_assertionresult.hpp +#define TWOBLUECUBES_CATCH_ASSERTIONRESULT_HPP_INCLUDED + +namespace Catch { + + AssertionInfo::AssertionInfo( char const * _macroName, + SourceLineInfo const& _lineInfo, + char const * _capturedExpression, + ResultDisposition::Flags _resultDisposition, + char const * _secondArg) + : macroName( _macroName ), + lineInfo( _lineInfo ), + capturedExpression( _capturedExpression ), + resultDisposition( _resultDisposition ), + secondArg( _secondArg ) + {} + + AssertionResult::AssertionResult() {} + + AssertionResult::AssertionResult( AssertionInfo const& info, AssertionResultData const& data ) + : m_info( info ), + m_resultData( data ) + {} + + AssertionResult::~AssertionResult() {} + + // Result was a success + bool AssertionResult::succeeded() const { + return Catch::isOk( m_resultData.resultType ); + } + + // Result was a success, or failure is suppressed + bool AssertionResult::isOk() const { + return Catch::isOk( m_resultData.resultType ) || shouldSuppressFailure( m_info.resultDisposition ); + } + + ResultWas::OfType AssertionResult::getResultType() const { + return m_resultData.resultType; + } + + bool AssertionResult::hasExpression() const { + return m_info.capturedExpression[0] != 0; + } + + bool AssertionResult::hasMessage() const { + return !m_resultData.message.empty(); + } + + std::string capturedExpressionWithSecondArgument( char const * capturedExpression, char const * secondArg ) { + return (secondArg[0] == 0 || secondArg[0] == '"' && secondArg[1] == '"') + ? capturedExpression + : std::string(capturedExpression) + ", " + secondArg; + } + + std::string AssertionResult::getExpression() const { + if( isFalseTest( m_info.resultDisposition ) ) + return '!' + capturedExpressionWithSecondArgument(m_info.capturedExpression, m_info.secondArg); + else + return capturedExpressionWithSecondArgument(m_info.capturedExpression, m_info.secondArg); + } + std::string AssertionResult::getExpressionInMacro() const { + if( m_info.macroName[0] == 0 ) + return capturedExpressionWithSecondArgument(m_info.capturedExpression, m_info.secondArg); + else + return std::string(m_info.macroName) + "( " + capturedExpressionWithSecondArgument(m_info.capturedExpression, m_info.secondArg) + " )"; + } + + bool AssertionResult::hasExpandedExpression() const { + return hasExpression() && getExpandedExpression() != getExpression(); + } + + std::string AssertionResult::getExpandedExpression() const { + return m_resultData.reconstructExpression(); + } + + std::string AssertionResult::getMessage() const { + return m_resultData.message; + } + SourceLineInfo AssertionResult::getSourceInfo() const { + return m_info.lineInfo; + } + + std::string AssertionResult::getTestMacroName() const { + return m_info.macroName; + } + + void AssertionResult::discardDecomposedExpression() const { + m_resultData.decomposedExpression = CATCH_NULL; + } + + void AssertionResult::expandDecomposedExpression() const { + m_resultData.reconstructExpression(); + } + +} // end namespace Catch + +// #included from: catch_test_case_info.hpp +#define TWOBLUECUBES_CATCH_TEST_CASE_INFO_HPP_INCLUDED + +#include + +namespace Catch { + + inline TestCaseInfo::SpecialProperties parseSpecialTag( std::string const& tag ) { + if( startsWith( tag, '.' ) || + tag == "hide" || + tag == "!hide" ) + return TestCaseInfo::IsHidden; + else if( tag == "!throws" ) + return TestCaseInfo::Throws; + else if( tag == "!shouldfail" ) + return TestCaseInfo::ShouldFail; + else if( tag == "!mayfail" ) + return TestCaseInfo::MayFail; + else if( tag == "!nonportable" ) + return TestCaseInfo::NonPortable; + else + return TestCaseInfo::None; + } + inline bool isReservedTag( std::string const& tag ) { + return parseSpecialTag( tag ) == TestCaseInfo::None && tag.size() > 0 && !std::isalnum( tag[0] ); + } + inline void enforceNotReservedTag( std::string const& tag, SourceLineInfo const& _lineInfo ) { + if( isReservedTag( tag ) ) { + std::ostringstream ss; + ss << Colour(Colour::Red) + << "Tag name [" << tag << "] not allowed.\n" + << "Tag names starting with non alpha-numeric characters are reserved\n" + << Colour(Colour::FileName) + << _lineInfo << '\n'; + throw std::runtime_error(ss.str()); + } + } + + TestCase makeTestCase( ITestCase* _testCase, + std::string const& _className, + std::string const& _name, + std::string const& _descOrTags, + SourceLineInfo const& _lineInfo ) + { + bool isHidden( startsWith( _name, "./" ) ); // Legacy support + + // Parse out tags + std::set tags; + std::string desc, tag; + bool inTag = false; + for( std::size_t i = 0; i < _descOrTags.size(); ++i ) { + char c = _descOrTags[i]; + if( !inTag ) { + if( c == '[' ) + inTag = true; + else + desc += c; + } + else { + if( c == ']' ) { + TestCaseInfo::SpecialProperties prop = parseSpecialTag( tag ); + if( prop == TestCaseInfo::IsHidden ) + isHidden = true; + else if( prop == TestCaseInfo::None ) + enforceNotReservedTag( tag, _lineInfo ); + + tags.insert( tag ); + tag.clear(); + inTag = false; + } + else + tag += c; + } + } + if( isHidden ) { + tags.insert( "hide" ); + tags.insert( "." ); + } + + TestCaseInfo info( _name, _className, desc, tags, _lineInfo ); + return TestCase( _testCase, info ); + } + + void setTags( TestCaseInfo& testCaseInfo, std::set const& tags ) + { + testCaseInfo.tags = tags; + testCaseInfo.lcaseTags.clear(); + + std::ostringstream oss; + for( std::set::const_iterator it = tags.begin(), itEnd = tags.end(); it != itEnd; ++it ) { + oss << '[' << *it << ']'; + std::string lcaseTag = toLower( *it ); + testCaseInfo.properties = static_cast( testCaseInfo.properties | parseSpecialTag( lcaseTag ) ); + testCaseInfo.lcaseTags.insert( lcaseTag ); + } + testCaseInfo.tagsAsString = oss.str(); + } + + TestCaseInfo::TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::set const& _tags, + SourceLineInfo const& _lineInfo ) + : name( _name ), + className( _className ), + description( _description ), + lineInfo( _lineInfo ), + properties( None ) + { + setTags( *this, _tags ); + } + + TestCaseInfo::TestCaseInfo( TestCaseInfo const& other ) + : name( other.name ), + className( other.className ), + description( other.description ), + tags( other.tags ), + lcaseTags( other.lcaseTags ), + tagsAsString( other.tagsAsString ), + lineInfo( other.lineInfo ), + properties( other.properties ) + {} + + bool TestCaseInfo::isHidden() const { + return ( properties & IsHidden ) != 0; + } + bool TestCaseInfo::throws() const { + return ( properties & Throws ) != 0; + } + bool TestCaseInfo::okToFail() const { + return ( properties & (ShouldFail | MayFail ) ) != 0; + } + bool TestCaseInfo::expectedToFail() const { + return ( properties & (ShouldFail ) ) != 0; + } + + TestCase::TestCase( ITestCase* testCase, TestCaseInfo const& info ) : TestCaseInfo( info ), test( testCase ) {} + + TestCase::TestCase( TestCase const& other ) + : TestCaseInfo( other ), + test( other.test ) + {} + + TestCase TestCase::withName( std::string const& _newName ) const { + TestCase other( *this ); + other.name = _newName; + return other; + } + + void TestCase::swap( TestCase& other ) { + test.swap( other.test ); + name.swap( other.name ); + className.swap( other.className ); + description.swap( other.description ); + tags.swap( other.tags ); + lcaseTags.swap( other.lcaseTags ); + tagsAsString.swap( other.tagsAsString ); + std::swap( TestCaseInfo::properties, static_cast( other ).properties ); + std::swap( lineInfo, other.lineInfo ); + } + + void TestCase::invoke() const { + test->invoke(); + } + + bool TestCase::operator == ( TestCase const& other ) const { + return test.get() == other.test.get() && + name == other.name && + className == other.className; + } + + bool TestCase::operator < ( TestCase const& other ) const { + return name < other.name; + } + TestCase& TestCase::operator = ( TestCase const& other ) { + TestCase temp( other ); + swap( temp ); + return *this; + } + + TestCaseInfo const& TestCase::getTestCaseInfo() const + { + return *this; + } + +} // end namespace Catch + +// #included from: catch_version.hpp +#define TWOBLUECUBES_CATCH_VERSION_HPP_INCLUDED + +namespace Catch { + + Version::Version + ( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _patchNumber, + char const * const _branchName, + unsigned int _buildNumber ) + : majorVersion( _majorVersion ), + minorVersion( _minorVersion ), + patchNumber( _patchNumber ), + branchName( _branchName ), + buildNumber( _buildNumber ) + {} + + std::ostream& operator << ( std::ostream& os, Version const& version ) { + os << version.majorVersion << '.' + << version.minorVersion << '.' + << version.patchNumber; + // branchName is never null -> 0th char is \0 if it is empty + if (version.branchName[0]) { + os << '-' << version.branchName + << '.' << version.buildNumber; + } + return os; + } + + inline Version libraryVersion() { + static Version version( 1, 9, 6, "", 0 ); + return version; + } + +} + +// #included from: catch_message.hpp +#define TWOBLUECUBES_CATCH_MESSAGE_HPP_INCLUDED + +namespace Catch { + + MessageInfo::MessageInfo( std::string const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ) + : macroName( _macroName ), + lineInfo( _lineInfo ), + type( _type ), + sequence( ++globalCount ) + {} + + // This may need protecting if threading support is added + unsigned int MessageInfo::globalCount = 0; + + //////////////////////////////////////////////////////////////////////////// + + ScopedMessage::ScopedMessage( MessageBuilder const& builder ) + : m_info( builder.m_info ) + { + m_info.message = builder.m_stream.str(); + getResultCapture().pushScopedMessage( m_info ); + } + ScopedMessage::ScopedMessage( ScopedMessage const& other ) + : m_info( other.m_info ) + {} + + ScopedMessage::~ScopedMessage() { + if ( !std::uncaught_exception() ){ + getResultCapture().popScopedMessage(m_info); + } + } + +} // end namespace Catch + +// #included from: catch_legacy_reporter_adapter.hpp +#define TWOBLUECUBES_CATCH_LEGACY_REPORTER_ADAPTER_HPP_INCLUDED + +// #included from: catch_legacy_reporter_adapter.h +#define TWOBLUECUBES_CATCH_LEGACY_REPORTER_ADAPTER_H_INCLUDED + +namespace Catch +{ + // Deprecated + struct IReporter : IShared { + virtual ~IReporter(); + + virtual bool shouldRedirectStdout() const = 0; + + virtual void StartTesting() = 0; + virtual void EndTesting( Totals const& totals ) = 0; + virtual void StartGroup( std::string const& groupName ) = 0; + virtual void EndGroup( std::string const& groupName, Totals const& totals ) = 0; + virtual void StartTestCase( TestCaseInfo const& testInfo ) = 0; + virtual void EndTestCase( TestCaseInfo const& testInfo, Totals const& totals, std::string const& stdOut, std::string const& stdErr ) = 0; + virtual void StartSection( std::string const& sectionName, std::string const& description ) = 0; + virtual void EndSection( std::string const& sectionName, Counts const& assertions ) = 0; + virtual void NoAssertionsInSection( std::string const& sectionName ) = 0; + virtual void NoAssertionsInTestCase( std::string const& testName ) = 0; + virtual void Aborted() = 0; + virtual void Result( AssertionResult const& result ) = 0; + }; + + class LegacyReporterAdapter : public SharedImpl + { + public: + LegacyReporterAdapter( Ptr const& legacyReporter ); + virtual ~LegacyReporterAdapter(); + + virtual ReporterPreferences getPreferences() const; + virtual void noMatchingTestCases( std::string const& ); + virtual void testRunStarting( TestRunInfo const& ); + virtual void testGroupStarting( GroupInfo const& groupInfo ); + virtual void testCaseStarting( TestCaseInfo const& testInfo ); + virtual void sectionStarting( SectionInfo const& sectionInfo ); + virtual void assertionStarting( AssertionInfo const& ); + virtual bool assertionEnded( AssertionStats const& assertionStats ); + virtual void sectionEnded( SectionStats const& sectionStats ); + virtual void testCaseEnded( TestCaseStats const& testCaseStats ); + virtual void testGroupEnded( TestGroupStats const& testGroupStats ); + virtual void testRunEnded( TestRunStats const& testRunStats ); + virtual void skipTest( TestCaseInfo const& ); + + private: + Ptr m_legacyReporter; + }; +} + +namespace Catch +{ + LegacyReporterAdapter::LegacyReporterAdapter( Ptr const& legacyReporter ) + : m_legacyReporter( legacyReporter ) + {} + LegacyReporterAdapter::~LegacyReporterAdapter() {} + + ReporterPreferences LegacyReporterAdapter::getPreferences() const { + ReporterPreferences prefs; + prefs.shouldRedirectStdOut = m_legacyReporter->shouldRedirectStdout(); + return prefs; + } + + void LegacyReporterAdapter::noMatchingTestCases( std::string const& ) {} + void LegacyReporterAdapter::testRunStarting( TestRunInfo const& ) { + m_legacyReporter->StartTesting(); + } + void LegacyReporterAdapter::testGroupStarting( GroupInfo const& groupInfo ) { + m_legacyReporter->StartGroup( groupInfo.name ); + } + void LegacyReporterAdapter::testCaseStarting( TestCaseInfo const& testInfo ) { + m_legacyReporter->StartTestCase( testInfo ); + } + void LegacyReporterAdapter::sectionStarting( SectionInfo const& sectionInfo ) { + m_legacyReporter->StartSection( sectionInfo.name, sectionInfo.description ); + } + void LegacyReporterAdapter::assertionStarting( AssertionInfo const& ) { + // Not on legacy interface + } + + bool LegacyReporterAdapter::assertionEnded( AssertionStats const& assertionStats ) { + if( assertionStats.assertionResult.getResultType() != ResultWas::Ok ) { + for( std::vector::const_iterator it = assertionStats.infoMessages.begin(), itEnd = assertionStats.infoMessages.end(); + it != itEnd; + ++it ) { + if( it->type == ResultWas::Info ) { + ResultBuilder rb( it->macroName.c_str(), it->lineInfo, "", ResultDisposition::Normal ); + rb << it->message; + rb.setResultType( ResultWas::Info ); + AssertionResult result = rb.build(); + m_legacyReporter->Result( result ); + } + } + } + m_legacyReporter->Result( assertionStats.assertionResult ); + return true; + } + void LegacyReporterAdapter::sectionEnded( SectionStats const& sectionStats ) { + if( sectionStats.missingAssertions ) + m_legacyReporter->NoAssertionsInSection( sectionStats.sectionInfo.name ); + m_legacyReporter->EndSection( sectionStats.sectionInfo.name, sectionStats.assertions ); + } + void LegacyReporterAdapter::testCaseEnded( TestCaseStats const& testCaseStats ) { + m_legacyReporter->EndTestCase + ( testCaseStats.testInfo, + testCaseStats.totals, + testCaseStats.stdOut, + testCaseStats.stdErr ); + } + void LegacyReporterAdapter::testGroupEnded( TestGroupStats const& testGroupStats ) { + if( testGroupStats.aborting ) + m_legacyReporter->Aborted(); + m_legacyReporter->EndGroup( testGroupStats.groupInfo.name, testGroupStats.totals ); + } + void LegacyReporterAdapter::testRunEnded( TestRunStats const& testRunStats ) { + m_legacyReporter->EndTesting( testRunStats.totals ); + } + void LegacyReporterAdapter::skipTest( TestCaseInfo const& ) { + } +} + +// #included from: catch_timer.hpp + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wc++11-long-long" +#endif + +#ifdef CATCH_PLATFORM_WINDOWS + +#else + +#include + +#endif + +namespace Catch { + + namespace { +#ifdef CATCH_PLATFORM_WINDOWS + UInt64 getCurrentTicks() { + static UInt64 hz=0, hzo=0; + if (!hz) { + QueryPerformanceFrequency( reinterpret_cast( &hz ) ); + QueryPerformanceCounter( reinterpret_cast( &hzo ) ); + } + UInt64 t; + QueryPerformanceCounter( reinterpret_cast( &t ) ); + return ((t-hzo)*1000000)/hz; + } +#else + UInt64 getCurrentTicks() { + timeval t; + gettimeofday(&t,CATCH_NULL); + return static_cast( t.tv_sec ) * 1000000ull + static_cast( t.tv_usec ); + } +#endif + } + + void Timer::start() { + m_ticks = getCurrentTicks(); + } + unsigned int Timer::getElapsedMicroseconds() const { + return static_cast(getCurrentTicks() - m_ticks); + } + unsigned int Timer::getElapsedMilliseconds() const { + return static_cast(getElapsedMicroseconds()/1000); + } + double Timer::getElapsedSeconds() const { + return getElapsedMicroseconds()/1000000.0; + } + +} // namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif +// #included from: catch_common.hpp +#define TWOBLUECUBES_CATCH_COMMON_HPP_INCLUDED + +#include +#include + +namespace Catch { + + bool startsWith( std::string const& s, std::string const& prefix ) { + return s.size() >= prefix.size() && std::equal(prefix.begin(), prefix.end(), s.begin()); + } + bool startsWith( std::string const& s, char prefix ) { + return !s.empty() && s[0] == prefix; + } + bool endsWith( std::string const& s, std::string const& suffix ) { + return s.size() >= suffix.size() && std::equal(suffix.rbegin(), suffix.rend(), s.rbegin()); + } + bool endsWith( std::string const& s, char suffix ) { + return !s.empty() && s[s.size()-1] == suffix; + } + bool contains( std::string const& s, std::string const& infix ) { + return s.find( infix ) != std::string::npos; + } + char toLowerCh(char c) { + return static_cast( std::tolower( c ) ); + } + void toLowerInPlace( std::string& s ) { + std::transform( s.begin(), s.end(), s.begin(), toLowerCh ); + } + std::string toLower( std::string const& s ) { + std::string lc = s; + toLowerInPlace( lc ); + return lc; + } + std::string trim( std::string const& str ) { + static char const* whitespaceChars = "\n\r\t "; + std::string::size_type start = str.find_first_not_of( whitespaceChars ); + std::string::size_type end = str.find_last_not_of( whitespaceChars ); + + return start != std::string::npos ? str.substr( start, 1+end-start ) : std::string(); + } + + bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ) { + bool replaced = false; + std::size_t i = str.find( replaceThis ); + while( i != std::string::npos ) { + replaced = true; + str = str.substr( 0, i ) + withThis + str.substr( i+replaceThis.size() ); + if( i < str.size()-withThis.size() ) + i = str.find( replaceThis, i+withThis.size() ); + else + i = std::string::npos; + } + return replaced; + } + + pluralise::pluralise( std::size_t count, std::string const& label ) + : m_count( count ), + m_label( label ) + {} + + std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ) { + os << pluraliser.m_count << ' ' << pluraliser.m_label; + if( pluraliser.m_count != 1 ) + os << 's'; + return os; + } + + SourceLineInfo::SourceLineInfo() : file(""), line( 0 ){} + SourceLineInfo::SourceLineInfo( char const* _file, std::size_t _line ) + : file( _file ), + line( _line ) + {} + bool SourceLineInfo::empty() const { + return file[0] == '\0'; + } + bool SourceLineInfo::operator == ( SourceLineInfo const& other ) const { + return line == other.line && (file == other.file || std::strcmp(file, other.file) == 0); + } + bool SourceLineInfo::operator < ( SourceLineInfo const& other ) const { + return line < other.line || ( line == other.line && (std::strcmp(file, other.file) < 0)); + } + + void seedRng( IConfig const& config ) { + if( config.rngSeed() != 0 ) + std::srand( config.rngSeed() ); + } + unsigned int rngSeed() { + return getCurrentContext().getConfig()->rngSeed(); + } + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ) { +#ifndef __GNUG__ + os << info.file << '(' << info.line << ')'; +#else + os << info.file << ':' << info.line; +#endif + return os; + } + + void throwLogicError( std::string const& message, SourceLineInfo const& locationInfo ) { + std::ostringstream oss; + oss << locationInfo << ": Internal Catch error: '" << message << '\''; + if( alwaysTrue() ) + throw std::logic_error( oss.str() ); + } +} + +// #included from: catch_section.hpp +#define TWOBLUECUBES_CATCH_SECTION_HPP_INCLUDED + +namespace Catch { + + SectionInfo::SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name, + std::string const& _description ) + : name( _name ), + description( _description ), + lineInfo( _lineInfo ) + {} + + Section::Section( SectionInfo const& info ) + : m_info( info ), + m_sectionIncluded( getResultCapture().sectionStarted( m_info, m_assertions ) ) + { + m_timer.start(); + } + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4996) // std::uncaught_exception is deprecated in C++17 +#endif + Section::~Section() { + if( m_sectionIncluded ) { + SectionEndInfo endInfo( m_info, m_assertions, m_timer.getElapsedSeconds() ); + if( std::uncaught_exception() ) + getResultCapture().sectionEndedEarly( endInfo ); + else + getResultCapture().sectionEnded( endInfo ); + } + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + // This indicates whether the section should be executed or not + Section::operator bool() const { + return m_sectionIncluded; + } + +} // end namespace Catch + +// #included from: catch_debugger.hpp +#define TWOBLUECUBES_CATCH_DEBUGGER_HPP_INCLUDED + +#ifdef CATCH_PLATFORM_MAC + + #include + #include + #include + #include + #include + + namespace Catch{ + + // The following function is taken directly from the following technical note: + // http://developer.apple.com/library/mac/#qa/qa2004/qa1361.html + + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive(){ + + int mib[4]; + struct kinfo_proc info; + size_t size; + + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + + info.kp_proc.p_flag = 0; + + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + + // Call sysctl. + + size = sizeof(info); + if( sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, CATCH_NULL, 0) != 0 ) { + Catch::cerr() << "\n** Call to sysctl failed - unable to determine if debugger is active **\n" << std::endl; + return false; + } + + // We're being debugged if the P_TRACED flag is set. + + return ( (info.kp_proc.p_flag & P_TRACED) != 0 ); + } + } // namespace Catch + +#elif defined(CATCH_PLATFORM_LINUX) + #include + #include + + namespace Catch{ + // The standard POSIX way of detecting a debugger is to attempt to + // ptrace() the process, but this needs to be done from a child and not + // this process itself to still allow attaching to this process later + // if wanted, so is rather heavy. Under Linux we have the PID of the + // "debugger" (which doesn't need to be gdb, of course, it could also + // be strace, for example) in /proc/$PID/status, so just get it from + // there instead. + bool isDebuggerActive(){ + // Libstdc++ has a bug, where std::ifstream sets errno to 0 + // This way our users can properly assert over errno values + ErrnoGuard guard; + std::ifstream in("/proc/self/status"); + for( std::string line; std::getline(in, line); ) { + static const int PREFIX_LEN = 11; + if( line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0 ) { + // We're traced if the PID is not 0 and no other PID starts + // with 0 digit, so it's enough to check for just a single + // character. + return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; + } + } + + return false; + } + } // namespace Catch +#elif defined(_MSC_VER) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#else + namespace Catch { + inline bool isDebuggerActive() { return false; } + } +#endif // Platform + +#ifdef CATCH_PLATFORM_WINDOWS + + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + ::OutputDebugStringA( text.c_str() ); + } + } +#else + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + // !TBD: Need a version for Mac/ XCode and other IDEs + Catch::cout() << text; + } + } +#endif // Platform + +// #included from: catch_tostring.hpp +#define TWOBLUECUBES_CATCH_TOSTRING_HPP_INCLUDED + +namespace Catch { + +namespace Detail { + + const std::string unprintableString = "{?}"; + + namespace { + const int hexThreshold = 255; + + struct Endianness { + enum Arch { Big, Little }; + + static Arch which() { + union _{ + int asInt; + char asChar[sizeof (int)]; + } u; + + u.asInt = 1; + return ( u.asChar[sizeof(int)-1] == 1 ) ? Big : Little; + } + }; + } + + std::string rawMemoryToString( const void *object, std::size_t size ) + { + // Reverse order for little endian architectures + int i = 0, end = static_cast( size ), inc = 1; + if( Endianness::which() == Endianness::Little ) { + i = end-1; + end = inc = -1; + } + + unsigned char const *bytes = static_cast(object); + std::ostringstream os; + os << "0x" << std::setfill('0') << std::hex; + for( ; i != end; i += inc ) + os << std::setw(2) << static_cast(bytes[i]); + return os.str(); + } +} + +std::string toString( std::string const& value ) { + std::string s = value; + if( getCurrentContext().getConfig()->showInvisibles() ) { + for(size_t i = 0; i < s.size(); ++i ) { + std::string subs; + switch( s[i] ) { + case '\n': subs = "\\n"; break; + case '\t': subs = "\\t"; break; + default: break; + } + if( !subs.empty() ) { + s = s.substr( 0, i ) + subs + s.substr( i+1 ); + ++i; + } + } + } + return '"' + s + '"'; +} +std::string toString( std::wstring const& value ) { + + std::string s; + s.reserve( value.size() ); + for(size_t i = 0; i < value.size(); ++i ) + s += value[i] <= 0xff ? static_cast( value[i] ) : '?'; + return Catch::toString( s ); +} + +std::string toString( const char* const value ) { + return value ? Catch::toString( std::string( value ) ) : std::string( "{null string}" ); +} + +std::string toString( char* const value ) { + return Catch::toString( static_cast( value ) ); +} + +std::string toString( const wchar_t* const value ) +{ + return value ? Catch::toString( std::wstring(value) ) : std::string( "{null string}" ); +} + +std::string toString( wchar_t* const value ) +{ + return Catch::toString( static_cast( value ) ); +} + +std::string toString( int value ) { + std::ostringstream oss; + oss << value; + if( value > Detail::hexThreshold ) + oss << " (0x" << std::hex << value << ')'; + return oss.str(); +} + +std::string toString( unsigned long value ) { + std::ostringstream oss; + oss << value; + if( value > Detail::hexThreshold ) + oss << " (0x" << std::hex << value << ')'; + return oss.str(); +} + +std::string toString( unsigned int value ) { + return Catch::toString( static_cast( value ) ); +} + +template +std::string fpToString( T value, int precision ) { + std::ostringstream oss; + oss << std::setprecision( precision ) + << std::fixed + << value; + std::string d = oss.str(); + std::size_t i = d.find_last_not_of( '0' ); + if( i != std::string::npos && i != d.size()-1 ) { + if( d[i] == '.' ) + i++; + d = d.substr( 0, i+1 ); + } + return d; +} + +std::string toString( const double value ) { + return fpToString( value, 10 ); +} +std::string toString( const float value ) { + return fpToString( value, 5 ) + 'f'; +} + +std::string toString( bool value ) { + return value ? "true" : "false"; +} + +std::string toString( char value ) { + if ( value == '\r' ) + return "'\\r'"; + if ( value == '\f' ) + return "'\\f'"; + if ( value == '\n' ) + return "'\\n'"; + if ( value == '\t' ) + return "'\\t'"; + if ( '\0' <= value && value < ' ' ) + return toString( static_cast( value ) ); + char chstr[] = "' '"; + chstr[1] = value; + return chstr; +} + +std::string toString( signed char value ) { + return toString( static_cast( value ) ); +} + +std::string toString( unsigned char value ) { + return toString( static_cast( value ) ); +} + +#ifdef CATCH_CONFIG_CPP11_LONG_LONG +std::string toString( long long value ) { + std::ostringstream oss; + oss << value; + if( value > Detail::hexThreshold ) + oss << " (0x" << std::hex << value << ')'; + return oss.str(); +} +std::string toString( unsigned long long value ) { + std::ostringstream oss; + oss << value; + if( value > Detail::hexThreshold ) + oss << " (0x" << std::hex << value << ')'; + return oss.str(); +} +#endif + +#ifdef CATCH_CONFIG_CPP11_NULLPTR +std::string toString( std::nullptr_t ) { + return "nullptr"; +} +#endif + +#ifdef __OBJC__ + std::string toString( NSString const * const& nsstring ) { + if( !nsstring ) + return "nil"; + return "@" + toString([nsstring UTF8String]); + } + std::string toString( NSString * CATCH_ARC_STRONG & nsstring ) { + if( !nsstring ) + return "nil"; + return "@" + toString([nsstring UTF8String]); + } + std::string toString( NSObject* const& nsObject ) { + return toString( [nsObject description] ); + } +#endif + +} // end namespace Catch + +// #included from: catch_result_builder.hpp +#define TWOBLUECUBES_CATCH_RESULT_BUILDER_HPP_INCLUDED + +namespace Catch { + + ResultBuilder::ResultBuilder( char const* macroName, + SourceLineInfo const& lineInfo, + char const* capturedExpression, + ResultDisposition::Flags resultDisposition, + char const* secondArg ) + : m_assertionInfo( macroName, lineInfo, capturedExpression, resultDisposition, secondArg ), + m_shouldDebugBreak( false ), + m_shouldThrow( false ), + m_guardException( false ) + { + m_stream().oss.str(""); + } + + ResultBuilder::~ResultBuilder() { +#if defined(CATCH_CONFIG_FAST_COMPILE) + if ( m_guardException ) { + m_stream().oss << "Exception translation was disabled by CATCH_CONFIG_FAST_COMPILE"; + captureResult( ResultWas::ThrewException ); + getCurrentContext().getResultCapture()->exceptionEarlyReported(); + } +#endif + } + + ResultBuilder& ResultBuilder::setResultType( ResultWas::OfType result ) { + m_data.resultType = result; + return *this; + } + ResultBuilder& ResultBuilder::setResultType( bool result ) { + m_data.resultType = result ? ResultWas::Ok : ResultWas::ExpressionFailed; + return *this; + } + + void ResultBuilder::endExpression( DecomposedExpression const& expr ) { + AssertionResult result = build( expr ); + handleResult( result ); + } + + void ResultBuilder::useActiveException( ResultDisposition::Flags resultDisposition ) { + m_assertionInfo.resultDisposition = resultDisposition; + m_stream().oss << Catch::translateActiveException(); + captureResult( ResultWas::ThrewException ); + } + + void ResultBuilder::captureResult( ResultWas::OfType resultType ) { + setResultType( resultType ); + captureExpression(); + } + + void ResultBuilder::captureExpectedException( std::string const& expectedMessage ) { + if( expectedMessage.empty() ) + captureExpectedException( Matchers::Impl::MatchAllOf() ); + else + captureExpectedException( Matchers::Equals( expectedMessage ) ); + } + + void ResultBuilder::captureExpectedException( Matchers::Impl::MatcherBase const& matcher ) { + + assert( !isFalseTest( m_assertionInfo.resultDisposition ) ); + AssertionResultData data = m_data; + data.resultType = ResultWas::Ok; + data.reconstructedExpression = capturedExpressionWithSecondArgument(m_assertionInfo.capturedExpression, m_assertionInfo.secondArg); + + std::string actualMessage = Catch::translateActiveException(); + if( !matcher.match( actualMessage ) ) { + data.resultType = ResultWas::ExpressionFailed; + data.reconstructedExpression = actualMessage; + } + AssertionResult result( m_assertionInfo, data ); + handleResult( result ); + } + + void ResultBuilder::captureExpression() { + AssertionResult result = build(); + handleResult( result ); + } + + void ResultBuilder::handleResult( AssertionResult const& result ) + { + getResultCapture().assertionEnded( result ); + + if( !result.isOk() ) { + if( getCurrentContext().getConfig()->shouldDebugBreak() ) + m_shouldDebugBreak = true; + if( getCurrentContext().getRunner()->aborting() || (m_assertionInfo.resultDisposition & ResultDisposition::Normal) ) + m_shouldThrow = true; + } + } + + void ResultBuilder::react() { +#if defined(CATCH_CONFIG_FAST_COMPILE) + if (m_shouldDebugBreak) { + /////////////////////////////////////////////////////////////////// + // To inspect the state during test, you need to go one level up the callstack + // To go back to the test and change execution, jump over the throw statement + /////////////////////////////////////////////////////////////////// + CATCH_BREAK_INTO_DEBUGGER(); + } +#endif + if( m_shouldThrow ) + throw Catch::TestFailureException(); + } + + bool ResultBuilder::shouldDebugBreak() const { return m_shouldDebugBreak; } + bool ResultBuilder::allowThrows() const { return getCurrentContext().getConfig()->allowThrows(); } + + AssertionResult ResultBuilder::build() const + { + return build( *this ); + } + + // CAVEAT: The returned AssertionResult stores a pointer to the argument expr, + // a temporary DecomposedExpression, which in turn holds references to + // operands, possibly temporary as well. + // It should immediately be passed to handleResult; if the expression + // needs to be reported, its string expansion must be composed before + // the temporaries are destroyed. + AssertionResult ResultBuilder::build( DecomposedExpression const& expr ) const + { + assert( m_data.resultType != ResultWas::Unknown ); + AssertionResultData data = m_data; + + // Flip bool results if FalseTest flag is set + if( isFalseTest( m_assertionInfo.resultDisposition ) ) { + data.negate( expr.isBinaryExpression() ); + } + + data.message = m_stream().oss.str(); + data.decomposedExpression = &expr; // for lazy reconstruction + return AssertionResult( m_assertionInfo, data ); + } + + void ResultBuilder::reconstructExpression( std::string& dest ) const { + dest = capturedExpressionWithSecondArgument(m_assertionInfo.capturedExpression, m_assertionInfo.secondArg); + } + + void ResultBuilder::setExceptionGuard() { + m_guardException = true; + } + void ResultBuilder::unsetExceptionGuard() { + m_guardException = false; + } + +} // end namespace Catch + +// #included from: catch_tag_alias_registry.hpp +#define TWOBLUECUBES_CATCH_TAG_ALIAS_REGISTRY_HPP_INCLUDED + +namespace Catch { + + TagAliasRegistry::~TagAliasRegistry() {} + + Option TagAliasRegistry::find( std::string const& alias ) const { + std::map::const_iterator it = m_registry.find( alias ); + if( it != m_registry.end() ) + return it->second; + else + return Option(); + } + + std::string TagAliasRegistry::expandAliases( std::string const& unexpandedTestSpec ) const { + std::string expandedTestSpec = unexpandedTestSpec; + for( std::map::const_iterator it = m_registry.begin(), itEnd = m_registry.end(); + it != itEnd; + ++it ) { + std::size_t pos = expandedTestSpec.find( it->first ); + if( pos != std::string::npos ) { + expandedTestSpec = expandedTestSpec.substr( 0, pos ) + + it->second.tag + + expandedTestSpec.substr( pos + it->first.size() ); + } + } + return expandedTestSpec; + } + + void TagAliasRegistry::add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) { + + if( !startsWith( alias, "[@" ) || !endsWith( alias, ']' ) ) { + std::ostringstream oss; + oss << Colour( Colour::Red ) + << "error: tag alias, \"" << alias << "\" is not of the form [@alias name].\n" + << Colour( Colour::FileName ) + << lineInfo << '\n'; + throw std::domain_error( oss.str().c_str() ); + } + if( !m_registry.insert( std::make_pair( alias, TagAlias( tag, lineInfo ) ) ).second ) { + std::ostringstream oss; + oss << Colour( Colour::Red ) + << "error: tag alias, \"" << alias << "\" already registered.\n" + << "\tFirst seen at " + << Colour( Colour::Red ) << find(alias)->lineInfo << '\n' + << Colour( Colour::Red ) << "\tRedefined at " + << Colour( Colour::FileName) << lineInfo << '\n'; + throw std::domain_error( oss.str().c_str() ); + } + } + + ITagAliasRegistry::~ITagAliasRegistry() {} + + ITagAliasRegistry const& ITagAliasRegistry::get() { + return getRegistryHub().getTagAliasRegistry(); + } + + RegistrarForTagAliases::RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo ) { + getMutableRegistryHub().registerTagAlias( alias, tag, lineInfo ); + } + +} // end namespace Catch + +// #included from: catch_matchers_string.hpp + +namespace Catch { +namespace Matchers { + + namespace StdString { + + CasedString::CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity ) + : m_caseSensitivity( caseSensitivity ), + m_str( adjustString( str ) ) + {} + std::string CasedString::adjustString( std::string const& str ) const { + return m_caseSensitivity == CaseSensitive::No + ? toLower( str ) + : str; + } + std::string CasedString::caseSensitivitySuffix() const { + return m_caseSensitivity == CaseSensitive::No + ? " (case insensitive)" + : std::string(); + } + + StringMatcherBase::StringMatcherBase( std::string const& operation, CasedString const& comparator ) + : m_comparator( comparator ), + m_operation( operation ) { + } + + std::string StringMatcherBase::describe() const { + std::string description; + description.reserve(5 + m_operation.size() + m_comparator.m_str.size() + + m_comparator.caseSensitivitySuffix().size()); + description += m_operation; + description += ": \""; + description += m_comparator.m_str; + description += "\""; + description += m_comparator.caseSensitivitySuffix(); + return description; + } + + EqualsMatcher::EqualsMatcher( CasedString const& comparator ) : StringMatcherBase( "equals", comparator ) {} + + bool EqualsMatcher::match( std::string const& source ) const { + return m_comparator.adjustString( source ) == m_comparator.m_str; + } + + ContainsMatcher::ContainsMatcher( CasedString const& comparator ) : StringMatcherBase( "contains", comparator ) {} + + bool ContainsMatcher::match( std::string const& source ) const { + return contains( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + StartsWithMatcher::StartsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "starts with", comparator ) {} + + bool StartsWithMatcher::match( std::string const& source ) const { + return startsWith( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + EndsWithMatcher::EndsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "ends with", comparator ) {} + + bool EndsWithMatcher::match( std::string const& source ) const { + return endsWith( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + } // namespace StdString + + StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::EqualsMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::ContainsMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::EndsWithMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::StartsWithMatcher( StdString::CasedString( str, caseSensitivity) ); + } + +} // namespace Matchers +} // namespace Catch +// #included from: ../reporters/catch_reporter_multi.hpp +#define TWOBLUECUBES_CATCH_REPORTER_MULTI_HPP_INCLUDED + +namespace Catch { + +class MultipleReporters : public SharedImpl { + typedef std::vector > Reporters; + Reporters m_reporters; + +public: + void add( Ptr const& reporter ) { + m_reporters.push_back( reporter ); + } + +public: // IStreamingReporter + + virtual ReporterPreferences getPreferences() const CATCH_OVERRIDE { + return m_reporters[0]->getPreferences(); + } + + virtual void noMatchingTestCases( std::string const& spec ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->noMatchingTestCases( spec ); + } + + virtual void testRunStarting( TestRunInfo const& testRunInfo ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->testRunStarting( testRunInfo ); + } + + virtual void testGroupStarting( GroupInfo const& groupInfo ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->testGroupStarting( groupInfo ); + } + + virtual void testCaseStarting( TestCaseInfo const& testInfo ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->testCaseStarting( testInfo ); + } + + virtual void sectionStarting( SectionInfo const& sectionInfo ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->sectionStarting( sectionInfo ); + } + + virtual void assertionStarting( AssertionInfo const& assertionInfo ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->assertionStarting( assertionInfo ); + } + + // The return value indicates if the messages buffer should be cleared: + virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE { + bool clearBuffer = false; + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + clearBuffer |= (*it)->assertionEnded( assertionStats ); + return clearBuffer; + } + + virtual void sectionEnded( SectionStats const& sectionStats ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->sectionEnded( sectionStats ); + } + + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->testCaseEnded( testCaseStats ); + } + + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->testGroupEnded( testGroupStats ); + } + + virtual void testRunEnded( TestRunStats const& testRunStats ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->testRunEnded( testRunStats ); + } + + virtual void skipTest( TestCaseInfo const& testInfo ) CATCH_OVERRIDE { + for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end(); + it != itEnd; + ++it ) + (*it)->skipTest( testInfo ); + } + + virtual MultipleReporters* tryAsMulti() CATCH_OVERRIDE { + return this; + } + +}; + +Ptr addReporter( Ptr const& existingReporter, Ptr const& additionalReporter ) { + Ptr resultingReporter; + + if( existingReporter ) { + MultipleReporters* multi = existingReporter->tryAsMulti(); + if( !multi ) { + multi = new MultipleReporters; + resultingReporter = Ptr( multi ); + if( existingReporter ) + multi->add( existingReporter ); + } + else + resultingReporter = existingReporter; + multi->add( additionalReporter ); + } + else + resultingReporter = additionalReporter; + + return resultingReporter; +} + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_xml.hpp +#define TWOBLUECUBES_CATCH_REPORTER_XML_HPP_INCLUDED + +// #included from: catch_reporter_bases.hpp +#define TWOBLUECUBES_CATCH_REPORTER_BASES_HPP_INCLUDED + +#include +#include +#include +#include + +namespace Catch { + + namespace { + // Because formatting using c++ streams is stateful, drop down to C is required + // Alternatively we could use stringstream, but its performance is... not good. + std::string getFormattedDuration( double duration ) { + // Max exponent + 1 is required to represent the whole part + // + 1 for decimal point + // + 3 for the 3 decimal places + // + 1 for null terminator + const size_t maxDoubleSize = DBL_MAX_10_EXP + 1 + 1 + 3 + 1; + char buffer[maxDoubleSize]; + + // Save previous errno, to prevent sprintf from overwriting it + ErrnoGuard guard; +#ifdef _MSC_VER + sprintf_s(buffer, "%.3f", duration); +#else + sprintf(buffer, "%.3f", duration); +#endif + return std::string(buffer); + } + } + + struct StreamingReporterBase : SharedImpl { + + StreamingReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = false; + } + + virtual ReporterPreferences getPreferences() const CATCH_OVERRIDE { + return m_reporterPrefs; + } + + virtual ~StreamingReporterBase() CATCH_OVERRIDE; + + virtual void noMatchingTestCases( std::string const& ) CATCH_OVERRIDE {} + + virtual void testRunStarting( TestRunInfo const& _testRunInfo ) CATCH_OVERRIDE { + currentTestRunInfo = _testRunInfo; + } + virtual void testGroupStarting( GroupInfo const& _groupInfo ) CATCH_OVERRIDE { + currentGroupInfo = _groupInfo; + } + + virtual void testCaseStarting( TestCaseInfo const& _testInfo ) CATCH_OVERRIDE { + currentTestCaseInfo = _testInfo; + } + virtual void sectionStarting( SectionInfo const& _sectionInfo ) CATCH_OVERRIDE { + m_sectionStack.push_back( _sectionInfo ); + } + + virtual void sectionEnded( SectionStats const& /* _sectionStats */ ) CATCH_OVERRIDE { + m_sectionStack.pop_back(); + } + virtual void testCaseEnded( TestCaseStats const& /* _testCaseStats */ ) CATCH_OVERRIDE { + currentTestCaseInfo.reset(); + } + virtual void testGroupEnded( TestGroupStats const& /* _testGroupStats */ ) CATCH_OVERRIDE { + currentGroupInfo.reset(); + } + virtual void testRunEnded( TestRunStats const& /* _testRunStats */ ) CATCH_OVERRIDE { + currentTestCaseInfo.reset(); + currentGroupInfo.reset(); + currentTestRunInfo.reset(); + } + + virtual void skipTest( TestCaseInfo const& ) CATCH_OVERRIDE { + // Don't do anything with this by default. + // It can optionally be overridden in the derived class. + } + + Ptr m_config; + std::ostream& stream; + + LazyStat currentTestRunInfo; + LazyStat currentGroupInfo; + LazyStat currentTestCaseInfo; + + std::vector m_sectionStack; + ReporterPreferences m_reporterPrefs; + }; + + struct CumulativeReporterBase : SharedImpl { + template + struct Node : SharedImpl<> { + explicit Node( T const& _value ) : value( _value ) {} + virtual ~Node() {} + + typedef std::vector > ChildNodes; + T value; + ChildNodes children; + }; + struct SectionNode : SharedImpl<> { + explicit SectionNode( SectionStats const& _stats ) : stats( _stats ) {} + virtual ~SectionNode(); + + bool operator == ( SectionNode const& other ) const { + return stats.sectionInfo.lineInfo == other.stats.sectionInfo.lineInfo; + } + bool operator == ( Ptr const& other ) const { + return operator==( *other ); + } + + SectionStats stats; + typedef std::vector > ChildSections; + typedef std::vector Assertions; + ChildSections childSections; + Assertions assertions; + std::string stdOut; + std::string stdErr; + }; + + struct BySectionInfo { + BySectionInfo( SectionInfo const& other ) : m_other( other ) {} + BySectionInfo( BySectionInfo const& other ) : m_other( other.m_other ) {} + bool operator() ( Ptr const& node ) const { + return node->stats.sectionInfo.lineInfo == m_other.lineInfo; + } + private: + void operator=( BySectionInfo const& ); + SectionInfo const& m_other; + }; + + typedef Node TestCaseNode; + typedef Node TestGroupNode; + typedef Node TestRunNode; + + CumulativeReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = false; + } + ~CumulativeReporterBase(); + + virtual ReporterPreferences getPreferences() const CATCH_OVERRIDE { + return m_reporterPrefs; + } + + virtual void testRunStarting( TestRunInfo const& ) CATCH_OVERRIDE {} + virtual void testGroupStarting( GroupInfo const& ) CATCH_OVERRIDE {} + + virtual void testCaseStarting( TestCaseInfo const& ) CATCH_OVERRIDE {} + + virtual void sectionStarting( SectionInfo const& sectionInfo ) CATCH_OVERRIDE { + SectionStats incompleteStats( sectionInfo, Counts(), 0, false ); + Ptr node; + if( m_sectionStack.empty() ) { + if( !m_rootSection ) + m_rootSection = new SectionNode( incompleteStats ); + node = m_rootSection; + } + else { + SectionNode& parentNode = *m_sectionStack.back(); + SectionNode::ChildSections::const_iterator it = + std::find_if( parentNode.childSections.begin(), + parentNode.childSections.end(), + BySectionInfo( sectionInfo ) ); + if( it == parentNode.childSections.end() ) { + node = new SectionNode( incompleteStats ); + parentNode.childSections.push_back( node ); + } + else + node = *it; + } + m_sectionStack.push_back( node ); + m_deepestSection = node; + } + + virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE {} + + virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE { + assert( !m_sectionStack.empty() ); + SectionNode& sectionNode = *m_sectionStack.back(); + sectionNode.assertions.push_back( assertionStats ); + // AssertionResult holds a pointer to a temporary DecomposedExpression, + // which getExpandedExpression() calls to build the expression string. + // Our section stack copy of the assertionResult will likely outlive the + // temporary, so it must be expanded or discarded now to avoid calling + // a destroyed object later. + prepareExpandedExpression( sectionNode.assertions.back().assertionResult ); + return true; + } + virtual void sectionEnded( SectionStats const& sectionStats ) CATCH_OVERRIDE { + assert( !m_sectionStack.empty() ); + SectionNode& node = *m_sectionStack.back(); + node.stats = sectionStats; + m_sectionStack.pop_back(); + } + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE { + Ptr node = new TestCaseNode( testCaseStats ); + assert( m_sectionStack.size() == 0 ); + node->children.push_back( m_rootSection ); + m_testCases.push_back( node ); + m_rootSection.reset(); + + assert( m_deepestSection ); + m_deepestSection->stdOut = testCaseStats.stdOut; + m_deepestSection->stdErr = testCaseStats.stdErr; + } + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE { + Ptr node = new TestGroupNode( testGroupStats ); + node->children.swap( m_testCases ); + m_testGroups.push_back( node ); + } + virtual void testRunEnded( TestRunStats const& testRunStats ) CATCH_OVERRIDE { + Ptr node = new TestRunNode( testRunStats ); + node->children.swap( m_testGroups ); + m_testRuns.push_back( node ); + testRunEndedCumulative(); + } + virtual void testRunEndedCumulative() = 0; + + virtual void skipTest( TestCaseInfo const& ) CATCH_OVERRIDE {} + + virtual void prepareExpandedExpression( AssertionResult& result ) const { + if( result.isOk() ) + result.discardDecomposedExpression(); + else + result.expandDecomposedExpression(); + } + + Ptr m_config; + std::ostream& stream; + std::vector m_assertions; + std::vector > > m_sections; + std::vector > m_testCases; + std::vector > m_testGroups; + + std::vector > m_testRuns; + + Ptr m_rootSection; + Ptr m_deepestSection; + std::vector > m_sectionStack; + ReporterPreferences m_reporterPrefs; + + }; + + template + char const* getLineOfChars() { + static char line[CATCH_CONFIG_CONSOLE_WIDTH] = {0}; + if( !*line ) { + std::memset( line, C, CATCH_CONFIG_CONSOLE_WIDTH-1 ); + line[CATCH_CONFIG_CONSOLE_WIDTH-1] = 0; + } + return line; + } + + struct TestEventListenerBase : StreamingReporterBase { + TestEventListenerBase( ReporterConfig const& _config ) + : StreamingReporterBase( _config ) + {} + + virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE {} + virtual bool assertionEnded( AssertionStats const& ) CATCH_OVERRIDE { + return false; + } + }; + +} // end namespace Catch + +// #included from: ../internal/catch_reporter_registrars.hpp +#define TWOBLUECUBES_CATCH_REPORTER_REGISTRARS_HPP_INCLUDED + +namespace Catch { + + template + class LegacyReporterRegistrar { + + class ReporterFactory : public IReporterFactory { + virtual IStreamingReporter* create( ReporterConfig const& config ) const { + return new LegacyReporterAdapter( new T( config ) ); + } + + virtual std::string getDescription() const { + return T::getDescription(); + } + }; + + public: + + LegacyReporterRegistrar( std::string const& name ) { + getMutableRegistryHub().registerReporter( name, new ReporterFactory() ); + } + }; + + template + class ReporterRegistrar { + + class ReporterFactory : public SharedImpl { + + // *** Please Note ***: + // - If you end up here looking at a compiler error because it's trying to register + // your custom reporter class be aware that the native reporter interface has changed + // to IStreamingReporter. The "legacy" interface, IReporter, is still supported via + // an adapter. Just use REGISTER_LEGACY_REPORTER to take advantage of the adapter. + // However please consider updating to the new interface as the old one is now + // deprecated and will probably be removed quite soon! + // Please contact me via github if you have any questions at all about this. + // In fact, ideally, please contact me anyway to let me know you've hit this - as I have + // no idea who is actually using custom reporters at all (possibly no-one!). + // The new interface is designed to minimise exposure to interface changes in the future. + virtual IStreamingReporter* create( ReporterConfig const& config ) const { + return new T( config ); + } + + virtual std::string getDescription() const { + return T::getDescription(); + } + }; + + public: + + ReporterRegistrar( std::string const& name ) { + getMutableRegistryHub().registerReporter( name, new ReporterFactory() ); + } + }; + + template + class ListenerRegistrar { + + class ListenerFactory : public SharedImpl { + + virtual IStreamingReporter* create( ReporterConfig const& config ) const { + return new T( config ); + } + virtual std::string getDescription() const { + return std::string(); + } + }; + + public: + + ListenerRegistrar() { + getMutableRegistryHub().registerListener( new ListenerFactory() ); + } + }; +} + +#define INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) \ + namespace{ Catch::LegacyReporterRegistrar catch_internal_RegistrarFor##reporterType( name ); } + +#define INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) \ + namespace{ Catch::ReporterRegistrar catch_internal_RegistrarFor##reporterType( name ); } + +// Deprecated - use the form without INTERNAL_ +#define INTERNAL_CATCH_REGISTER_LISTENER( listenerType ) \ + namespace{ Catch::ListenerRegistrar catch_internal_RegistrarFor##listenerType; } + +#define CATCH_REGISTER_LISTENER( listenerType ) \ + namespace{ Catch::ListenerRegistrar catch_internal_RegistrarFor##listenerType; } + +// #included from: ../internal/catch_xmlwriter.hpp +#define TWOBLUECUBES_CATCH_XMLWRITER_HPP_INCLUDED + +#include +#include +#include +#include + +namespace Catch { + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void encodeTo( std::ostream& os ) const { + + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: http://www.w3.org/TR/xml/#syntax) + + for( std::size_t i = 0; i < m_str.size(); ++ i ) { + char c = m_str[i]; + switch( c ) { + case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: http://www.w3.org/TR/xml/#syntax + if( i > 2 && m_str[i-1] == ']' && m_str[i-2] == ']' ) + os << ">"; + else + os << c; + break; + + case '\"': + if( m_forWhat == ForAttributes ) + os << """; + else + os << c; + break; + + default: + // Escape control chars - based on contribution by @espenalb in PR #465 and + // by @mrpi PR #588 + if ( ( c >= 0 && c < '\x09' ) || ( c > '\x0D' && c < '\x20') || c=='\x7F' ) { + // see http://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + os << "\\x" << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast( c ); + } + else + os << c; + } + } + } + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + ScopedElement( ScopedElement const& other ) + : m_writer( other.m_writer ){ + other.m_writer = CATCH_NULL; + } + + ~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + ScopedElement& writeText( std::string const& text, bool indent = true ) { + m_writer->writeText( text, indent ); + return *this; + } + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer; + }; + + XmlWriter() + : m_tagIsOpen( false ), + m_needsNewline( false ), + m_os( Catch::cout() ) + { + writeDeclaration(); + } + + XmlWriter( std::ostream& os ) + : m_tagIsOpen( false ), + m_needsNewline( false ), + m_os( os ) + { + writeDeclaration(); + } + + ~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + ScopedElement scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << ""; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; + return *this; + } + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + std::ostringstream oss; + oss << attribute; + return writeAttribute( name, oss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + XmlWriter& writeComment( std::string const& text ) { + ensureTagClosed(); + m_os << m_indent << ""; + m_needsNewline = true; + return *this; + } + + void writeStylesheetRef( std::string const& url ) { + m_os << "\n"; + } + + XmlWriter& writeBlankLine() { + ensureTagClosed(); + m_os << '\n'; + return *this; + } + + void ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + private: + XmlWriter( XmlWriter const& ); + void operator=( XmlWriter const& ); + + void writeDeclaration() { + m_os << "\n"; + } + + void newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } + + bool m_tagIsOpen; + bool m_needsNewline; + std::vector m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +} + +namespace Catch { + class XmlReporter : public StreamingReporterBase { + public: + XmlReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ), + m_xml(_config.stream()), + m_sectionDepth( 0 ) + { + m_reporterPrefs.shouldRedirectStdOut = true; + } + + virtual ~XmlReporter() CATCH_OVERRIDE; + + static std::string getDescription() { + return "Reports test results as an XML document"; + } + + virtual std::string getStylesheetRef() const { + return std::string(); + } + + void writeSourceInfo( SourceLineInfo const& sourceInfo ) { + m_xml + .writeAttribute( "filename", sourceInfo.file ) + .writeAttribute( "line", sourceInfo.line ); + } + + public: // StreamingReporterBase + + virtual void noMatchingTestCases( std::string const& s ) CATCH_OVERRIDE { + StreamingReporterBase::noMatchingTestCases( s ); + } + + virtual void testRunStarting( TestRunInfo const& testInfo ) CATCH_OVERRIDE { + StreamingReporterBase::testRunStarting( testInfo ); + std::string stylesheetRef = getStylesheetRef(); + if( !stylesheetRef.empty() ) + m_xml.writeStylesheetRef( stylesheetRef ); + m_xml.startElement( "Catch" ); + if( !m_config->name().empty() ) + m_xml.writeAttribute( "name", m_config->name() ); + } + + virtual void testGroupStarting( GroupInfo const& groupInfo ) CATCH_OVERRIDE { + StreamingReporterBase::testGroupStarting( groupInfo ); + m_xml.startElement( "Group" ) + .writeAttribute( "name", groupInfo.name ); + } + + virtual void testCaseStarting( TestCaseInfo const& testInfo ) CATCH_OVERRIDE { + StreamingReporterBase::testCaseStarting(testInfo); + m_xml.startElement( "TestCase" ) + .writeAttribute( "name", trim( testInfo.name ) ) + .writeAttribute( "description", testInfo.description ) + .writeAttribute( "tags", testInfo.tagsAsString ); + + writeSourceInfo( testInfo.lineInfo ); + + if ( m_config->showDurations() == ShowDurations::Always ) + m_testCaseTimer.start(); + m_xml.ensureTagClosed(); + } + + virtual void sectionStarting( SectionInfo const& sectionInfo ) CATCH_OVERRIDE { + StreamingReporterBase::sectionStarting( sectionInfo ); + if( m_sectionDepth++ > 0 ) { + m_xml.startElement( "Section" ) + .writeAttribute( "name", trim( sectionInfo.name ) ) + .writeAttribute( "description", sectionInfo.description ); + writeSourceInfo( sectionInfo.lineInfo ); + m_xml.ensureTagClosed(); + } + } + + virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE { } + + virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE { + + AssertionResult const& result = assertionStats.assertionResult; + + bool includeResults = m_config->includeSuccessfulResults() || !result.isOk(); + + if( includeResults ) { + // Print any info messages in tags. + for( std::vector::const_iterator it = assertionStats.infoMessages.begin(), itEnd = assertionStats.infoMessages.end(); + it != itEnd; + ++it ) { + if( it->type == ResultWas::Info ) { + m_xml.scopedElement( "Info" ) + .writeText( it->message ); + } else if ( it->type == ResultWas::Warning ) { + m_xml.scopedElement( "Warning" ) + .writeText( it->message ); + } + } + } + + // Drop out if result was successful but we're not printing them. + if( !includeResults && result.getResultType() != ResultWas::Warning ) + return true; + + // Print the expression if there is one. + if( result.hasExpression() ) { + m_xml.startElement( "Expression" ) + .writeAttribute( "success", result.succeeded() ) + .writeAttribute( "type", result.getTestMacroName() ); + + writeSourceInfo( result.getSourceInfo() ); + + m_xml.scopedElement( "Original" ) + .writeText( result.getExpression() ); + m_xml.scopedElement( "Expanded" ) + .writeText( result.getExpandedExpression() ); + } + + // And... Print a result applicable to each result type. + switch( result.getResultType() ) { + case ResultWas::ThrewException: + m_xml.startElement( "Exception" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + case ResultWas::FatalErrorCondition: + m_xml.startElement( "FatalErrorCondition" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + case ResultWas::Info: + m_xml.scopedElement( "Info" ) + .writeText( result.getMessage() ); + break; + case ResultWas::Warning: + // Warning will already have been written + break; + case ResultWas::ExplicitFailure: + m_xml.startElement( "Failure" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + default: + break; + } + + if( result.hasExpression() ) + m_xml.endElement(); + + return true; + } + + virtual void sectionEnded( SectionStats const& sectionStats ) CATCH_OVERRIDE { + StreamingReporterBase::sectionEnded( sectionStats ); + if( --m_sectionDepth > 0 ) { + XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResults" ); + e.writeAttribute( "successes", sectionStats.assertions.passed ); + e.writeAttribute( "failures", sectionStats.assertions.failed ); + e.writeAttribute( "expectedFailures", sectionStats.assertions.failedButOk ); + + if ( m_config->showDurations() == ShowDurations::Always ) + e.writeAttribute( "durationInSeconds", sectionStats.durationInSeconds ); + + m_xml.endElement(); + } + } + + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE { + StreamingReporterBase::testCaseEnded( testCaseStats ); + XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResult" ); + e.writeAttribute( "success", testCaseStats.totals.assertions.allOk() ); + + if ( m_config->showDurations() == ShowDurations::Always ) + e.writeAttribute( "durationInSeconds", m_testCaseTimer.getElapsedSeconds() ); + + if( !testCaseStats.stdOut.empty() ) + m_xml.scopedElement( "StdOut" ).writeText( trim( testCaseStats.stdOut ), false ); + if( !testCaseStats.stdErr.empty() ) + m_xml.scopedElement( "StdErr" ).writeText( trim( testCaseStats.stdErr ), false ); + + m_xml.endElement(); + } + + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE { + StreamingReporterBase::testGroupEnded( testGroupStats ); + // TODO: Check testGroupStats.aborting and act accordingly. + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", testGroupStats.totals.assertions.passed ) + .writeAttribute( "failures", testGroupStats.totals.assertions.failed ) + .writeAttribute( "expectedFailures", testGroupStats.totals.assertions.failedButOk ); + m_xml.endElement(); + } + + virtual void testRunEnded( TestRunStats const& testRunStats ) CATCH_OVERRIDE { + StreamingReporterBase::testRunEnded( testRunStats ); + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", testRunStats.totals.assertions.passed ) + .writeAttribute( "failures", testRunStats.totals.assertions.failed ) + .writeAttribute( "expectedFailures", testRunStats.totals.assertions.failedButOk ); + m_xml.endElement(); + } + + private: + Timer m_testCaseTimer; + XmlWriter m_xml; + int m_sectionDepth; + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "xml", XmlReporter ) + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_junit.hpp +#define TWOBLUECUBES_CATCH_REPORTER_JUNIT_HPP_INCLUDED + +#include + +namespace Catch { + + namespace { + std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + const size_t timeStampSize = sizeof("2017-01-16T17:06:45Z"); + +#ifdef _MSC_VER + std::tm timeInfo = {}; + gmtime_s(&timeInfo, &rawtime); +#else + std::tm* timeInfo; + timeInfo = std::gmtime(&rawtime); +#endif + + char timeStamp[timeStampSize]; + const char * const fmt = "%Y-%m-%dT%H:%M:%SZ"; + +#ifdef _MSC_VER + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); +#else + std::strftime(timeStamp, timeStampSize, fmt, timeInfo); +#endif + return std::string(timeStamp); + } + + } + + class JunitReporter : public CumulativeReporterBase { + public: + JunitReporter( ReporterConfig const& _config ) + : CumulativeReporterBase( _config ), + xml( _config.stream() ), + m_okToFail( false ) + { + m_reporterPrefs.shouldRedirectStdOut = true; + } + + virtual ~JunitReporter() CATCH_OVERRIDE; + + static std::string getDescription() { + return "Reports test results in an XML format that looks like Ant's junitreport target"; + } + + virtual void noMatchingTestCases( std::string const& /*spec*/ ) CATCH_OVERRIDE {} + + virtual void testRunStarting( TestRunInfo const& runInfo ) CATCH_OVERRIDE { + CumulativeReporterBase::testRunStarting( runInfo ); + xml.startElement( "testsuites" ); + } + + virtual void testGroupStarting( GroupInfo const& groupInfo ) CATCH_OVERRIDE { + suiteTimer.start(); + stdOutForSuite.str(""); + stdErrForSuite.str(""); + unexpectedExceptions = 0; + CumulativeReporterBase::testGroupStarting( groupInfo ); + } + + virtual void testCaseStarting( TestCaseInfo const& testCaseInfo ) CATCH_OVERRIDE { + m_okToFail = testCaseInfo.okToFail(); + } + virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE { + if( assertionStats.assertionResult.getResultType() == ResultWas::ThrewException && !m_okToFail ) + unexpectedExceptions++; + return CumulativeReporterBase::assertionEnded( assertionStats ); + } + + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE { + stdOutForSuite << testCaseStats.stdOut; + stdErrForSuite << testCaseStats.stdErr; + CumulativeReporterBase::testCaseEnded( testCaseStats ); + } + + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE { + double suiteTime = suiteTimer.getElapsedSeconds(); + CumulativeReporterBase::testGroupEnded( testGroupStats ); + writeGroup( *m_testGroups.back(), suiteTime ); + } + + virtual void testRunEndedCumulative() CATCH_OVERRIDE { + xml.endElement(); + } + + void writeGroup( TestGroupNode const& groupNode, double suiteTime ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testsuite" ); + TestGroupStats const& stats = groupNode.value; + xml.writeAttribute( "name", stats.groupInfo.name ); + xml.writeAttribute( "errors", unexpectedExceptions ); + xml.writeAttribute( "failures", stats.totals.assertions.failed-unexpectedExceptions ); + xml.writeAttribute( "tests", stats.totals.assertions.total() ); + xml.writeAttribute( "hostname", "tbd" ); // !TBD + if( m_config->showDurations() == ShowDurations::Never ) + xml.writeAttribute( "time", "" ); + else + xml.writeAttribute( "time", suiteTime ); + xml.writeAttribute( "timestamp", getCurrentTimestamp() ); + + // Write test cases + for( TestGroupNode::ChildNodes::const_iterator + it = groupNode.children.begin(), itEnd = groupNode.children.end(); + it != itEnd; + ++it ) + writeTestCase( **it ); + + xml.scopedElement( "system-out" ).writeText( trim( stdOutForSuite.str() ), false ); + xml.scopedElement( "system-err" ).writeText( trim( stdErrForSuite.str() ), false ); + } + + void writeTestCase( TestCaseNode const& testCaseNode ) { + TestCaseStats const& stats = testCaseNode.value; + + // All test cases have exactly one section - which represents the + // test case itself. That section may have 0-n nested sections + assert( testCaseNode.children.size() == 1 ); + SectionNode const& rootSection = *testCaseNode.children.front(); + + std::string className = stats.testInfo.className; + + if( className.empty() ) { + if( rootSection.childSections.empty() ) + className = "global"; + } + writeSection( className, "", rootSection ); + } + + void writeSection( std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode ) { + std::string name = trim( sectionNode.stats.sectionInfo.name ); + if( !rootName.empty() ) + name = rootName + '/' + name; + + if( !sectionNode.assertions.empty() || + !sectionNode.stdOut.empty() || + !sectionNode.stdErr.empty() ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testcase" ); + if( className.empty() ) { + xml.writeAttribute( "classname", name ); + xml.writeAttribute( "name", "root" ); + } + else { + xml.writeAttribute( "classname", className ); + xml.writeAttribute( "name", name ); + } + xml.writeAttribute( "time", Catch::toString( sectionNode.stats.durationInSeconds ) ); + + writeAssertions( sectionNode ); + + if( !sectionNode.stdOut.empty() ) + xml.scopedElement( "system-out" ).writeText( trim( sectionNode.stdOut ), false ); + if( !sectionNode.stdErr.empty() ) + xml.scopedElement( "system-err" ).writeText( trim( sectionNode.stdErr ), false ); + } + for( SectionNode::ChildSections::const_iterator + it = sectionNode.childSections.begin(), + itEnd = sectionNode.childSections.end(); + it != itEnd; + ++it ) + if( className.empty() ) + writeSection( name, "", **it ); + else + writeSection( className, name, **it ); + } + + void writeAssertions( SectionNode const& sectionNode ) { + for( SectionNode::Assertions::const_iterator + it = sectionNode.assertions.begin(), itEnd = sectionNode.assertions.end(); + it != itEnd; + ++it ) + writeAssertion( *it ); + } + void writeAssertion( AssertionStats const& stats ) { + AssertionResult const& result = stats.assertionResult; + if( !result.isOk() ) { + std::string elementName; + switch( result.getResultType() ) { + case ResultWas::ThrewException: + case ResultWas::FatalErrorCondition: + elementName = "error"; + break; + case ResultWas::ExplicitFailure: + elementName = "failure"; + break; + case ResultWas::ExpressionFailed: + elementName = "failure"; + break; + case ResultWas::DidntThrowException: + elementName = "failure"; + break; + + // We should never see these here: + case ResultWas::Info: + case ResultWas::Warning: + case ResultWas::Ok: + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + elementName = "internalError"; + break; + } + + XmlWriter::ScopedElement e = xml.scopedElement( elementName ); + + xml.writeAttribute( "message", result.getExpandedExpression() ); + xml.writeAttribute( "type", result.getTestMacroName() ); + + std::ostringstream oss; + if( !result.getMessage().empty() ) + oss << result.getMessage() << '\n'; + for( std::vector::const_iterator + it = stats.infoMessages.begin(), + itEnd = stats.infoMessages.end(); + it != itEnd; + ++it ) + if( it->type == ResultWas::Info ) + oss << it->message << '\n'; + + oss << "at " << result.getSourceInfo(); + xml.writeText( oss.str(), false ); + } + } + + XmlWriter xml; + Timer suiteTimer; + std::ostringstream stdOutForSuite; + std::ostringstream stdErrForSuite; + unsigned int unexpectedExceptions; + bool m_okToFail; + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "junit", JunitReporter ) + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_console.hpp +#define TWOBLUECUBES_CATCH_REPORTER_CONSOLE_HPP_INCLUDED + +#include +#include + +namespace Catch { + + struct ConsoleReporter : StreamingReporterBase { + ConsoleReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ), + m_headerPrinted( false ) + {} + + virtual ~ConsoleReporter() CATCH_OVERRIDE; + static std::string getDescription() { + return "Reports test results as plain lines of text"; + } + + virtual void noMatchingTestCases( std::string const& spec ) CATCH_OVERRIDE { + stream << "No test cases matched '" << spec << '\'' << std::endl; + } + + virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE { + } + + virtual bool assertionEnded( AssertionStats const& _assertionStats ) CATCH_OVERRIDE { + AssertionResult const& result = _assertionStats.assertionResult; + + bool includeResults = m_config->includeSuccessfulResults() || !result.isOk(); + + // Drop out if result was successful but we're not printing them. + if( !includeResults && result.getResultType() != ResultWas::Warning ) + return false; + + lazyPrint(); + + AssertionPrinter printer( stream, _assertionStats, includeResults ); + printer.print(); + stream << std::endl; + return true; + } + + virtual void sectionStarting( SectionInfo const& _sectionInfo ) CATCH_OVERRIDE { + m_headerPrinted = false; + StreamingReporterBase::sectionStarting( _sectionInfo ); + } + virtual void sectionEnded( SectionStats const& _sectionStats ) CATCH_OVERRIDE { + if( _sectionStats.missingAssertions ) { + lazyPrint(); + Colour colour( Colour::ResultError ); + if( m_sectionStack.size() > 1 ) + stream << "\nNo assertions in section"; + else + stream << "\nNo assertions in test case"; + stream << " '" << _sectionStats.sectionInfo.name << "'\n" << std::endl; + } + if( m_config->showDurations() == ShowDurations::Always ) { + stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl; + } + if( m_headerPrinted ) { + m_headerPrinted = false; + } + StreamingReporterBase::sectionEnded( _sectionStats ); + } + + virtual void testCaseEnded( TestCaseStats const& _testCaseStats ) CATCH_OVERRIDE { + StreamingReporterBase::testCaseEnded( _testCaseStats ); + m_headerPrinted = false; + } + virtual void testGroupEnded( TestGroupStats const& _testGroupStats ) CATCH_OVERRIDE { + if( currentGroupInfo.used ) { + printSummaryDivider(); + stream << "Summary for group '" << _testGroupStats.groupInfo.name << "':\n"; + printTotals( _testGroupStats.totals ); + stream << '\n' << std::endl; + } + StreamingReporterBase::testGroupEnded( _testGroupStats ); + } + virtual void testRunEnded( TestRunStats const& _testRunStats ) CATCH_OVERRIDE { + printTotalsDivider( _testRunStats.totals ); + printTotals( _testRunStats.totals ); + stream << std::endl; + StreamingReporterBase::testRunEnded( _testRunStats ); + } + + private: + + class AssertionPrinter { + void operator= ( AssertionPrinter const& ); + public: + AssertionPrinter( std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages ) + : stream( _stream ), + stats( _stats ), + result( _stats.assertionResult ), + colour( Colour::None ), + message( result.getMessage() ), + messages( _stats.infoMessages ), + printInfoMessages( _printInfoMessages ) + { + switch( result.getResultType() ) { + case ResultWas::Ok: + colour = Colour::Success; + passOrFail = "PASSED"; + //if( result.hasMessage() ) + if( _stats.infoMessages.size() == 1 ) + messageLabel = "with message"; + if( _stats.infoMessages.size() > 1 ) + messageLabel = "with messages"; + break; + case ResultWas::ExpressionFailed: + if( result.isOk() ) { + colour = Colour::Success; + passOrFail = "FAILED - but was ok"; + } + else { + colour = Colour::Error; + passOrFail = "FAILED"; + } + if( _stats.infoMessages.size() == 1 ) + messageLabel = "with message"; + if( _stats.infoMessages.size() > 1 ) + messageLabel = "with messages"; + break; + case ResultWas::ThrewException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to unexpected exception with "; + if (_stats.infoMessages.size() == 1) + messageLabel += "message"; + if (_stats.infoMessages.size() > 1) + messageLabel += "messages"; + break; + case ResultWas::FatalErrorCondition: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to a fatal error condition"; + break; + case ResultWas::DidntThrowException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "because no exception was thrown where one was expected"; + break; + case ResultWas::Info: + messageLabel = "info"; + break; + case ResultWas::Warning: + messageLabel = "warning"; + break; + case ResultWas::ExplicitFailure: + passOrFail = "FAILED"; + colour = Colour::Error; + if( _stats.infoMessages.size() == 1 ) + messageLabel = "explicitly with message"; + if( _stats.infoMessages.size() > 1 ) + messageLabel = "explicitly with messages"; + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + passOrFail = "** internal error **"; + colour = Colour::Error; + break; + } + } + + void print() const { + printSourceInfo(); + if( stats.totals.assertions.total() > 0 ) { + if( result.isOk() ) + stream << '\n'; + printResultType(); + printOriginalExpression(); + printReconstructedExpression(); + } + else { + stream << '\n'; + } + printMessage(); + } + + private: + void printResultType() const { + if( !passOrFail.empty() ) { + Colour colourGuard( colour ); + stream << passOrFail << ":\n"; + } + } + void printOriginalExpression() const { + if( result.hasExpression() ) { + Colour colourGuard( Colour::OriginalExpression ); + stream << " "; + stream << result.getExpressionInMacro(); + stream << '\n'; + } + } + void printReconstructedExpression() const { + if( result.hasExpandedExpression() ) { + stream << "with expansion:\n"; + Colour colourGuard( Colour::ReconstructedExpression ); + stream << Text( result.getExpandedExpression(), TextAttributes().setIndent(2) ) << '\n'; + } + } + void printMessage() const { + if( !messageLabel.empty() ) + stream << messageLabel << ':' << '\n'; + for( std::vector::const_iterator it = messages.begin(), itEnd = messages.end(); + it != itEnd; + ++it ) { + // If this assertion is a warning ignore any INFO messages + if( printInfoMessages || it->type != ResultWas::Info ) + stream << Text( it->message, TextAttributes().setIndent(2) ) << '\n'; + } + } + void printSourceInfo() const { + Colour colourGuard( Colour::FileName ); + stream << result.getSourceInfo() << ": "; + } + + std::ostream& stream; + AssertionStats const& stats; + AssertionResult const& result; + Colour::Code colour; + std::string passOrFail; + std::string messageLabel; + std::string message; + std::vector messages; + bool printInfoMessages; + }; + + void lazyPrint() { + + if( !currentTestRunInfo.used ) + lazyPrintRunInfo(); + if( !currentGroupInfo.used ) + lazyPrintGroupInfo(); + + if( !m_headerPrinted ) { + printTestCaseAndSectionHeader(); + m_headerPrinted = true; + } + } + void lazyPrintRunInfo() { + stream << '\n' << getLineOfChars<'~'>() << '\n'; + Colour colour( Colour::SecondaryText ); + stream << currentTestRunInfo->name + << " is a Catch v" << libraryVersion() << " host application.\n" + << "Run with -? for options\n\n"; + + if( m_config->rngSeed() != 0 ) + stream << "Randomness seeded to: " << m_config->rngSeed() << "\n\n"; + + currentTestRunInfo.used = true; + } + void lazyPrintGroupInfo() { + if( !currentGroupInfo->name.empty() && currentGroupInfo->groupsCounts > 1 ) { + printClosedHeader( "Group: " + currentGroupInfo->name ); + currentGroupInfo.used = true; + } + } + void printTestCaseAndSectionHeader() { + assert( !m_sectionStack.empty() ); + printOpenHeader( currentTestCaseInfo->name ); + + if( m_sectionStack.size() > 1 ) { + Colour colourGuard( Colour::Headers ); + + std::vector::const_iterator + it = m_sectionStack.begin()+1, // Skip first section (test case) + itEnd = m_sectionStack.end(); + for( ; it != itEnd; ++it ) + printHeaderString( it->name, 2 ); + } + + SourceLineInfo lineInfo = m_sectionStack.back().lineInfo; + + if( !lineInfo.empty() ){ + stream << getLineOfChars<'-'>() << '\n'; + Colour colourGuard( Colour::FileName ); + stream << lineInfo << '\n'; + } + stream << getLineOfChars<'.'>() << '\n' << std::endl; + } + + void printClosedHeader( std::string const& _name ) { + printOpenHeader( _name ); + stream << getLineOfChars<'.'>() << '\n'; + } + void printOpenHeader( std::string const& _name ) { + stream << getLineOfChars<'-'>() << '\n'; + { + Colour colourGuard( Colour::Headers ); + printHeaderString( _name ); + } + } + + // if string has a : in first line will set indent to follow it on + // subsequent lines + void printHeaderString( std::string const& _string, std::size_t indent = 0 ) { + std::size_t i = _string.find( ": " ); + if( i != std::string::npos ) + i+=2; + else + i = 0; + stream << Text( _string, TextAttributes() + .setIndent( indent+i) + .setInitialIndent( indent ) ) << '\n'; + } + + struct SummaryColumn { + + SummaryColumn( std::string const& _label, Colour::Code _colour ) + : label( _label ), + colour( _colour ) + {} + SummaryColumn addRow( std::size_t count ) { + std::ostringstream oss; + oss << count; + std::string row = oss.str(); + for( std::vector::iterator it = rows.begin(); it != rows.end(); ++it ) { + while( it->size() < row.size() ) + *it = ' ' + *it; + while( it->size() > row.size() ) + row = ' ' + row; + } + rows.push_back( row ); + return *this; + } + + std::string label; + Colour::Code colour; + std::vector rows; + + }; + + void printTotals( Totals const& totals ) { + if( totals.testCases.total() == 0 ) { + stream << Colour( Colour::Warning ) << "No tests ran\n"; + } + else if( totals.assertions.total() > 0 && totals.testCases.allPassed() ) { + stream << Colour( Colour::ResultSuccess ) << "All tests passed"; + stream << " (" + << pluralise( totals.assertions.passed, "assertion" ) << " in " + << pluralise( totals.testCases.passed, "test case" ) << ')' + << '\n'; + } + else { + + std::vector columns; + columns.push_back( SummaryColumn( "", Colour::None ) + .addRow( totals.testCases.total() ) + .addRow( totals.assertions.total() ) ); + columns.push_back( SummaryColumn( "passed", Colour::Success ) + .addRow( totals.testCases.passed ) + .addRow( totals.assertions.passed ) ); + columns.push_back( SummaryColumn( "failed", Colour::ResultError ) + .addRow( totals.testCases.failed ) + .addRow( totals.assertions.failed ) ); + columns.push_back( SummaryColumn( "failed as expected", Colour::ResultExpectedFailure ) + .addRow( totals.testCases.failedButOk ) + .addRow( totals.assertions.failedButOk ) ); + + printSummaryRow( "test cases", columns, 0 ); + printSummaryRow( "assertions", columns, 1 ); + } + } + void printSummaryRow( std::string const& label, std::vector const& cols, std::size_t row ) { + for( std::vector::const_iterator it = cols.begin(); it != cols.end(); ++it ) { + std::string value = it->rows[row]; + if( it->label.empty() ) { + stream << label << ": "; + if( value != "0" ) + stream << value; + else + stream << Colour( Colour::Warning ) << "- none -"; + } + else if( value != "0" ) { + stream << Colour( Colour::LightGrey ) << " | "; + stream << Colour( it->colour ) + << value << ' ' << it->label; + } + } + stream << '\n'; + } + + static std::size_t makeRatio( std::size_t number, std::size_t total ) { + std::size_t ratio = total > 0 ? CATCH_CONFIG_CONSOLE_WIDTH * number/ total : 0; + return ( ratio == 0 && number > 0 ) ? 1 : ratio; + } + static std::size_t& findMax( std::size_t& i, std::size_t& j, std::size_t& k ) { + if( i > j && i > k ) + return i; + else if( j > k ) + return j; + else + return k; + } + + void printTotalsDivider( Totals const& totals ) { + if( totals.testCases.total() > 0 ) { + std::size_t failedRatio = makeRatio( totals.testCases.failed, totals.testCases.total() ); + std::size_t failedButOkRatio = makeRatio( totals.testCases.failedButOk, totals.testCases.total() ); + std::size_t passedRatio = makeRatio( totals.testCases.passed, totals.testCases.total() ); + while( failedRatio + failedButOkRatio + passedRatio < CATCH_CONFIG_CONSOLE_WIDTH-1 ) + findMax( failedRatio, failedButOkRatio, passedRatio )++; + while( failedRatio + failedButOkRatio + passedRatio > CATCH_CONFIG_CONSOLE_WIDTH-1 ) + findMax( failedRatio, failedButOkRatio, passedRatio )--; + + stream << Colour( Colour::Error ) << std::string( failedRatio, '=' ); + stream << Colour( Colour::ResultExpectedFailure ) << std::string( failedButOkRatio, '=' ); + if( totals.testCases.allPassed() ) + stream << Colour( Colour::ResultSuccess ) << std::string( passedRatio, '=' ); + else + stream << Colour( Colour::Success ) << std::string( passedRatio, '=' ); + } + else { + stream << Colour( Colour::Warning ) << std::string( CATCH_CONFIG_CONSOLE_WIDTH-1, '=' ); + } + stream << '\n'; + } + void printSummaryDivider() { + stream << getLineOfChars<'-'>() << '\n'; + } + + private: + bool m_headerPrinted; + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "console", ConsoleReporter ) + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_compact.hpp +#define TWOBLUECUBES_CATCH_REPORTER_COMPACT_HPP_INCLUDED + +namespace Catch { + + struct CompactReporter : StreamingReporterBase { + + CompactReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ) + {} + + virtual ~CompactReporter(); + + static std::string getDescription() { + return "Reports test results on a single line, suitable for IDEs"; + } + + virtual ReporterPreferences getPreferences() const { + ReporterPreferences prefs; + prefs.shouldRedirectStdOut = false; + return prefs; + } + + virtual void noMatchingTestCases( std::string const& spec ) { + stream << "No test cases matched '" << spec << '\'' << std::endl; + } + + virtual void assertionStarting( AssertionInfo const& ) {} + + virtual bool assertionEnded( AssertionStats const& _assertionStats ) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool printInfoMessages = true; + + // Drop out if result was successful and we're not printing those + if( !m_config->includeSuccessfulResults() && result.isOk() ) { + if( result.getResultType() != ResultWas::Warning ) + return false; + printInfoMessages = false; + } + + AssertionPrinter printer( stream, _assertionStats, printInfoMessages ); + printer.print(); + + stream << std::endl; + return true; + } + + virtual void sectionEnded(SectionStats const& _sectionStats) CATCH_OVERRIDE { + if (m_config->showDurations() == ShowDurations::Always) { + stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl; + } + } + + virtual void testRunEnded( TestRunStats const& _testRunStats ) { + printTotals( _testRunStats.totals ); + stream << '\n' << std::endl; + StreamingReporterBase::testRunEnded( _testRunStats ); + } + + private: + class AssertionPrinter { + void operator= ( AssertionPrinter const& ); + public: + AssertionPrinter( std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages ) + : stream( _stream ) + , stats( _stats ) + , result( _stats.assertionResult ) + , messages( _stats.infoMessages ) + , itMessage( _stats.infoMessages.begin() ) + , printInfoMessages( _printInfoMessages ) + {} + + void print() { + printSourceInfo(); + + itMessage = messages.begin(); + + switch( result.getResultType() ) { + case ResultWas::Ok: + printResultType( Colour::ResultSuccess, passedString() ); + printOriginalExpression(); + printReconstructedExpression(); + if ( ! result.hasExpression() ) + printRemainingMessages( Colour::None ); + else + printRemainingMessages(); + break; + case ResultWas::ExpressionFailed: + if( result.isOk() ) + printResultType( Colour::ResultSuccess, failedString() + std::string( " - but was ok" ) ); + else + printResultType( Colour::Error, failedString() ); + printOriginalExpression(); + printReconstructedExpression(); + printRemainingMessages(); + break; + case ResultWas::ThrewException: + printResultType( Colour::Error, failedString() ); + printIssue( "unexpected exception with message:" ); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::FatalErrorCondition: + printResultType( Colour::Error, failedString() ); + printIssue( "fatal error condition with message:" ); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::DidntThrowException: + printResultType( Colour::Error, failedString() ); + printIssue( "expected exception, got none" ); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::Info: + printResultType( Colour::None, "info" ); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::Warning: + printResultType( Colour::None, "warning" ); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::ExplicitFailure: + printResultType( Colour::Error, failedString() ); + printIssue( "explicitly" ); + printRemainingMessages( Colour::None ); + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + printResultType( Colour::Error, "** internal error **" ); + break; + } + } + + private: + // Colour::LightGrey + + static Colour::Code dimColour() { return Colour::FileName; } + +#ifdef CATCH_PLATFORM_MAC + static const char* failedString() { return "FAILED"; } + static const char* passedString() { return "PASSED"; } +#else + static const char* failedString() { return "failed"; } + static const char* passedString() { return "passed"; } +#endif + + void printSourceInfo() const { + Colour colourGuard( Colour::FileName ); + stream << result.getSourceInfo() << ':'; + } + + void printResultType( Colour::Code colour, std::string const& passOrFail ) const { + if( !passOrFail.empty() ) { + { + Colour colourGuard( colour ); + stream << ' ' << passOrFail; + } + stream << ':'; + } + } + + void printIssue( std::string const& issue ) const { + stream << ' ' << issue; + } + + void printExpressionWas() { + if( result.hasExpression() ) { + stream << ';'; + { + Colour colour( dimColour() ); + stream << " expression was:"; + } + printOriginalExpression(); + } + } + + void printOriginalExpression() const { + if( result.hasExpression() ) { + stream << ' ' << result.getExpression(); + } + } + + void printReconstructedExpression() const { + if( result.hasExpandedExpression() ) { + { + Colour colour( dimColour() ); + stream << " for: "; + } + stream << result.getExpandedExpression(); + } + } + + void printMessage() { + if ( itMessage != messages.end() ) { + stream << " '" << itMessage->message << '\''; + ++itMessage; + } + } + + void printRemainingMessages( Colour::Code colour = dimColour() ) { + if ( itMessage == messages.end() ) + return; + + // using messages.end() directly yields compilation error: + std::vector::const_iterator itEnd = messages.end(); + const std::size_t N = static_cast( std::distance( itMessage, itEnd ) ); + + { + Colour colourGuard( colour ); + stream << " with " << pluralise( N, "message" ) << ':'; + } + + for(; itMessage != itEnd; ) { + // If this assertion is a warning ignore any INFO messages + if( printInfoMessages || itMessage->type != ResultWas::Info ) { + stream << " '" << itMessage->message << '\''; + if ( ++itMessage != itEnd ) { + Colour colourGuard( dimColour() ); + stream << " and"; + } + } + } + } + + private: + std::ostream& stream; + AssertionStats const& stats; + AssertionResult const& result; + std::vector messages; + std::vector::const_iterator itMessage; + bool printInfoMessages; + }; + + // Colour, message variants: + // - white: No tests ran. + // - red: Failed [both/all] N test cases, failed [both/all] M assertions. + // - white: Passed [both/all] N test cases (no assertions). + // - red: Failed N tests cases, failed M assertions. + // - green: Passed [both/all] N tests cases with M assertions. + + std::string bothOrAll( std::size_t count ) const { + return count == 1 ? std::string() : count == 2 ? "both " : "all " ; + } + + void printTotals( const Totals& totals ) const { + if( totals.testCases.total() == 0 ) { + stream << "No tests ran."; + } + else if( totals.testCases.failed == totals.testCases.total() ) { + Colour colour( Colour::ResultError ); + const std::string qualify_assertions_failed = + totals.assertions.failed == totals.assertions.total() ? + bothOrAll( totals.assertions.failed ) : std::string(); + stream << + "Failed " << bothOrAll( totals.testCases.failed ) + << pluralise( totals.testCases.failed, "test case" ) << ", " + "failed " << qualify_assertions_failed << + pluralise( totals.assertions.failed, "assertion" ) << '.'; + } + else if( totals.assertions.total() == 0 ) { + stream << + "Passed " << bothOrAll( totals.testCases.total() ) + << pluralise( totals.testCases.total(), "test case" ) + << " (no assertions)."; + } + else if( totals.assertions.failed ) { + Colour colour( Colour::ResultError ); + stream << + "Failed " << pluralise( totals.testCases.failed, "test case" ) << ", " + "failed " << pluralise( totals.assertions.failed, "assertion" ) << '.'; + } + else { + Colour colour( Colour::ResultSuccess ); + stream << + "Passed " << bothOrAll( totals.testCases.passed ) + << pluralise( totals.testCases.passed, "test case" ) << + " with " << pluralise( totals.assertions.passed, "assertion" ) << '.'; + } + } + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "compact", CompactReporter ) + +} // end namespace Catch + +namespace Catch { + // These are all here to avoid warnings about not having any out of line + // virtual methods + NonCopyable::~NonCopyable() {} + IShared::~IShared() {} + IStream::~IStream() CATCH_NOEXCEPT {} + FileStream::~FileStream() CATCH_NOEXCEPT {} + CoutStream::~CoutStream() CATCH_NOEXCEPT {} + DebugOutStream::~DebugOutStream() CATCH_NOEXCEPT {} + StreamBufBase::~StreamBufBase() CATCH_NOEXCEPT {} + IContext::~IContext() {} + IResultCapture::~IResultCapture() {} + ITestCase::~ITestCase() {} + ITestCaseRegistry::~ITestCaseRegistry() {} + IRegistryHub::~IRegistryHub() {} + IMutableRegistryHub::~IMutableRegistryHub() {} + IExceptionTranslator::~IExceptionTranslator() {} + IExceptionTranslatorRegistry::~IExceptionTranslatorRegistry() {} + IReporter::~IReporter() {} + IReporterFactory::~IReporterFactory() {} + IReporterRegistry::~IReporterRegistry() {} + IStreamingReporter::~IStreamingReporter() {} + AssertionStats::~AssertionStats() {} + SectionStats::~SectionStats() {} + TestCaseStats::~TestCaseStats() {} + TestGroupStats::~TestGroupStats() {} + TestRunStats::~TestRunStats() {} + CumulativeReporterBase::SectionNode::~SectionNode() {} + CumulativeReporterBase::~CumulativeReporterBase() {} + + StreamingReporterBase::~StreamingReporterBase() {} + ConsoleReporter::~ConsoleReporter() {} + CompactReporter::~CompactReporter() {} + IRunner::~IRunner() {} + IMutableContext::~IMutableContext() {} + IConfig::~IConfig() {} + XmlReporter::~XmlReporter() {} + JunitReporter::~JunitReporter() {} + TestRegistry::~TestRegistry() {} + FreeFunctionTestCase::~FreeFunctionTestCase() {} + IGeneratorInfo::~IGeneratorInfo() {} + IGeneratorsForTest::~IGeneratorsForTest() {} + WildcardPattern::~WildcardPattern() {} + TestSpec::Pattern::~Pattern() {} + TestSpec::NamePattern::~NamePattern() {} + TestSpec::TagPattern::~TagPattern() {} + TestSpec::ExcludedPattern::~ExcludedPattern() {} + Matchers::Impl::MatcherUntypedBase::~MatcherUntypedBase() {} + + void Config::dummy() {} + + namespace TestCaseTracking { + ITracker::~ITracker() {} + TrackerBase::~TrackerBase() {} + SectionTracker::~SectionTracker() {} + IndexTracker::~IndexTracker() {} + } +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif + +#ifdef CATCH_CONFIG_MAIN +// #included from: internal/catch_default_main.hpp +#define TWOBLUECUBES_CATCH_DEFAULT_MAIN_HPP_INCLUDED + +#ifndef __OBJC__ + +#if defined(WIN32) && defined(_UNICODE) && !defined(DO_NOT_USE_WMAIN) +// Standard C/C++ Win32 Unicode wmain entry point +extern "C" int wmain (int argc, wchar_t * argv[], wchar_t * []) { +#else +// Standard C/C++ main entry point +int main (int argc, char * argv[]) { +#endif + + int result = Catch::Session().run( argc, argv ); + return ( result < 0xff ? result : 0xff ); +} + +#else // __OBJC__ + +// Objective-C entry point +int main (int argc, char * const argv[]) { +#if !CATCH_ARC_ENABLED + NSAutoreleasePool * pool = [[NSAutoreleasePool alloc] init]; +#endif + + Catch::registerTestMethods(); + int result = Catch::Session().run( argc, (char* const*)argv ); + +#if !CATCH_ARC_ENABLED + [pool drain]; +#endif + + return ( result < 0xff ? result : 0xff ); +} + +#endif // __OBJC__ + +#endif + +#ifdef CLARA_CONFIG_MAIN_NOT_DEFINED +# undef CLARA_CONFIG_MAIN +#endif + +////// + +// If this config identifier is defined then all CATCH macros are prefixed with CATCH_ +#ifdef CATCH_CONFIG_PREFIX_ALL + +#if defined(CATCH_CONFIG_FAST_COMPILE) +#define CATCH_REQUIRE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, expr ) +#define CATCH_REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr ) +#else +#define CATCH_REQUIRE( expr ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, expr ) +#define CATCH_REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr ) +#endif + +#define CATCH_REQUIRE_THROWS( expr ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, "", expr ) +#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr ) +#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr ) +#define CATCH_REQUIRE_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "CATCH_REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, expr ) + +#define CATCH_CHECK( expr ) INTERNAL_CATCH_TEST( "CATCH_CHECK", Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CATCH_CHECK_FALSE( expr ) INTERNAL_CATCH_TEST( "CATCH_CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, expr ) +#define CATCH_CHECKED_IF( expr ) INTERNAL_CATCH_IF( "CATCH_CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CATCH_CHECKED_ELSE( expr ) INTERNAL_CATCH_ELSE( "CATCH_CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CATCH_CHECK_NOFAIL( expr ) INTERNAL_CATCH_TEST( "CATCH_CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, expr ) + +#define CATCH_CHECK_THROWS( expr ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, "", expr ) +#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CATCH_CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#define CATCH_CHECK_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "CATCH_CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, expr ) + +#define CATCH_CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg ) + +#if defined(CATCH_CONFIG_FAST_COMPILE) +#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT_NO_TRY( "CATCH_REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#else +#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#endif + +#define CATCH_INFO( msg ) INTERNAL_CATCH_INFO( "CATCH_INFO", msg ) +#define CATCH_WARN( msg ) INTERNAL_CATCH_MSG( "CATCH_WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg ) +#define CATCH_SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( "CATCH_INFO", msg ) +#define CATCH_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CATCH_CAPTURE", #msg " := " << Catch::toString(msg) ) +#define CATCH_SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CATCH_CAPTURE", #msg " := " << Catch::toString(msg) ) + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) + #define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) + #define CATCH_METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) + #define CATCH_REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ ) + #define CATCH_SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) + #define CATCH_FAIL( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ ) + #define CATCH_FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + #define CATCH_SUCCEED( ... ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#else + #define CATCH_TEST_CASE( name, description ) INTERNAL_CATCH_TESTCASE( name, description ) + #define CATCH_TEST_CASE_METHOD( className, name, description ) INTERNAL_CATCH_TEST_CASE_METHOD( className, name, description ) + #define CATCH_METHOD_AS_TEST_CASE( method, name, description ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, name, description ) + #define CATCH_REGISTER_TEST_CASE( function, name, description ) INTERNAL_CATCH_REGISTER_TESTCASE( function, name, description ) + #define CATCH_SECTION( name, description ) INTERNAL_CATCH_SECTION( name, description ) + #define CATCH_FAIL( msg ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, msg ) + #define CATCH_FAIL_CHECK( msg ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, msg ) + #define CATCH_SUCCEED( msg ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, msg ) +#endif +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE( "", "" ) + +#define CATCH_REGISTER_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) +#define CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) + +#define CATCH_GENERATE( expr) INTERNAL_CATCH_GENERATE( expr ) + +// "BDD-style" convenience wrappers +#ifdef CATCH_CONFIG_VARIADIC_MACROS +#define CATCH_SCENARIO( ... ) CATCH_TEST_CASE( "Scenario: " __VA_ARGS__ ) +#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ ) +#else +#define CATCH_SCENARIO( name, tags ) CATCH_TEST_CASE( "Scenario: " name, tags ) +#define CATCH_SCENARIO_METHOD( className, name, tags ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " name, tags ) +#endif +#define CATCH_GIVEN( desc ) CATCH_SECTION( std::string( "Given: ") + desc, "" ) +#define CATCH_WHEN( desc ) CATCH_SECTION( std::string( " When: ") + desc, "" ) +#define CATCH_AND_WHEN( desc ) CATCH_SECTION( std::string( " And: ") + desc, "" ) +#define CATCH_THEN( desc ) CATCH_SECTION( std::string( " Then: ") + desc, "" ) +#define CATCH_AND_THEN( desc ) CATCH_SECTION( std::string( " And: ") + desc, "" ) + +// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required +#else + +#if defined(CATCH_CONFIG_FAST_COMPILE) +#define REQUIRE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "REQUIRE", Catch::ResultDisposition::Normal, expr ) +#define REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr ) + +#else +#define REQUIRE( expr ) INTERNAL_CATCH_TEST( "REQUIRE", Catch::ResultDisposition::Normal, expr ) +#define REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr ) +#endif + +#define REQUIRE_THROWS( expr ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS", Catch::ResultDisposition::Normal, "", expr ) +#define REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr ) +#define REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr ) +#define REQUIRE_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, expr ) + +#define CHECK( expr ) INTERNAL_CATCH_TEST( "CHECK", Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CHECK_FALSE( expr ) INTERNAL_CATCH_TEST( "CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, expr ) +#define CHECKED_IF( expr ) INTERNAL_CATCH_IF( "CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CHECKED_ELSE( expr ) INTERNAL_CATCH_ELSE( "CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CHECK_NOFAIL( expr ) INTERNAL_CATCH_TEST( "CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, expr ) + +#define CHECK_THROWS( expr ) INTERNAL_CATCH_THROWS( "CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, "", expr ) +#define CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#define CHECK_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, expr ) + +#define CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg ) + +#if defined(CATCH_CONFIG_FAST_COMPILE) +#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT_NO_TRY( "REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#else +#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#endif + +#define INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg ) +#define WARN( msg ) INTERNAL_CATCH_MSG( "WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg ) +#define SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg ) +#define CAPTURE( msg ) INTERNAL_CATCH_INFO( "CAPTURE", #msg " := " << Catch::toString(msg) ) +#define SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CAPTURE", #msg " := " << Catch::toString(msg) ) + +#ifdef CATCH_CONFIG_VARIADIC_MACROS +#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) +#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) +#define REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ ) +#define SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) +#define FAIL( ... ) INTERNAL_CATCH_MSG( "FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define SUCCEED( ... ) INTERNAL_CATCH_MSG( "SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#else +#define TEST_CASE( name, description ) INTERNAL_CATCH_TESTCASE( name, description ) + #define TEST_CASE_METHOD( className, name, description ) INTERNAL_CATCH_TEST_CASE_METHOD( className, name, description ) + #define METHOD_AS_TEST_CASE( method, name, description ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, name, description ) + #define REGISTER_TEST_CASE( method, name, description ) INTERNAL_CATCH_REGISTER_TESTCASE( method, name, description ) + #define SECTION( name, description ) INTERNAL_CATCH_SECTION( name, description ) + #define FAIL( msg ) INTERNAL_CATCH_MSG( "FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, msg ) + #define FAIL_CHECK( msg ) INTERNAL_CATCH_MSG( "FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, msg ) + #define SUCCEED( msg ) INTERNAL_CATCH_MSG( "SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, msg ) +#endif +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE( "", "" ) + +#define REGISTER_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) +#define REGISTER_LEGACY_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) + +#define GENERATE( expr) INTERNAL_CATCH_GENERATE( expr ) + +#endif + +#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) + +// "BDD-style" convenience wrappers +#ifdef CATCH_CONFIG_VARIADIC_MACROS +#define SCENARIO( ... ) TEST_CASE( "Scenario: " __VA_ARGS__ ) +#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ ) +#else +#define SCENARIO( name, tags ) TEST_CASE( "Scenario: " name, tags ) +#define SCENARIO_METHOD( className, name, tags ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " name, tags ) +#endif +#define GIVEN( desc ) SECTION( std::string(" Given: ") + desc, "" ) +#define WHEN( desc ) SECTION( std::string(" When: ") + desc, "" ) +#define AND_WHEN( desc ) SECTION( std::string("And when: ") + desc, "" ) +#define THEN( desc ) SECTION( std::string(" Then: ") + desc, "" ) +#define AND_THEN( desc ) SECTION( std::string(" And: ") + desc, "" ) + +using Catch::Detail::Approx; + +// #included from: internal/catch_reenable_warnings.h + +#define TWOBLUECUBES_CATCH_REENABLE_WARNINGS_H_INCLUDED + +#ifdef __clang__ +# ifdef __ICC // icpc defines the __clang__ macro +# pragma warning(pop) +# else +# pragma clang diagnostic pop +# endif +#elif defined __GNUC__ +# pragma GCC diagnostic pop +#endif + +#endif // TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED + diff --git a/src/dionysus/wasserstein/def_debug_ws.h b/src/dionysus/wasserstein/def_debug_ws.h new file mode 100755 index 0000000..791ce1d --- /dev/null +++ b/src/dionysus/wasserstein/def_debug_ws.h @@ -0,0 +1,44 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef DEF_DEBUG_WS_H +#define DEF_DEBUG_WS_H + +//#define DEBUG_BOUND_MATCH +//#define DEBUG_NEIGHBOUR_ORACLE +//#define DEBUG_MATCHING +//#define DEBUG_AUCTION +// This symbol should be defined only in the version +// for R package TDA, to comply with some CRAN rules +// like no usage of cout, cerr, cin, exit, etc. +//#define FOR_R_TDA +// +//#define DEBUG_KDTREE_RESTR_ORACLE +//#define DEBUG_STUPID_SPARSE_RESTR_ORACLE +//#define DEBUG_FR_AUCTION + +#endif diff --git a/src/dionysus/wasserstein/diagonal_heap.h b/src/dionysus/wasserstein/diagonal_heap.h new file mode 100755 index 0000000..9ffee70 --- /dev/null +++ b/src/dionysus/wasserstein/diagonal_heap.h @@ -0,0 +1,149 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef DIAGONAL_HEAP_H +#define DIAGONAL_HEAP_H + +//#define USE_BOOST_HEAP + +#include +#include +#include +#include + +#ifdef USE_BOOST_HEAP +#include +#endif + +#include "basic_defs_ws.h" + +namespace hera { +namespace ws { + +template +struct CompPairsBySecondLexStruct { + bool operator()(const IdxValPair& a, const IdxValPair& b) const + { + return a.second < b.second or (a.second == b.second and a.first > b.first); + } +}; + + +template +struct CompPairsBySecondGreaterStruct { + bool operator()(const IdxValPair& a, const IdxValPair& b) const + { + return a.second > b.second; + } +}; + +#ifdef USE_BOOST_HEAP +template +using LossesHeapOld = boost::heap::d_ary_heap, boost::heap::arity<2>, boost::heap::mutable_, boost::heap::compare>>; +#else +template +class IdxValHeap { +public: + using InternalKeeper = std::set, ComparisonStruct>; + using handle_type = typename InternalKeeper::iterator; + using const_handle_type = typename InternalKeeper::const_iterator; + // methods + handle_type push(const IdxValPair& val) + { + auto res_pair = _heap.insert(val); + assert(res_pair.second); + assert(res_pair.first != _heap.end()); + return res_pair.first; + } + + void decrease(handle_type& handle, const IdxValPair& new_val) + { + _heap.erase(handle); + handle = push(new_val); + } + + void increase(handle_type& handle, const IdxValPair& new_val) + { + _heap.erase(handle); + handle = push(new_val); + } + + size_t size() const + { + return _heap.size(); + } + + handle_type ordered_begin() + { + return _heap.begin(); + } + + handle_type ordered_end() + { + return _heap.end(); + } + + const_handle_type ordered_begin() const + { + return _heap.cbegin(); + } + + const_handle_type ordered_end() const + { + return _heap.cend(); + } + + +private: + std::set, ComparisonStruct> _heap; +}; + +// if we store losses, the minimal value should come first +template +using LossesHeapOld = IdxValHeap>; +#endif + +template +std::string losses_heap_to_string(const LossesHeapOld& h) +{ + std::stringstream result; + result << "["; + for(auto iter = h.ordered_begin(); iter != h.ordered_end(); ++iter) { + result << *iter; + if (std::next(iter) != h.ordered_end()) { + result << ", "; + } + } + result << "]"; + return result.str(); +} + +} // ws +} // hera + +#endif // DIAGONAL_HEAP_H diff --git a/src/dionysus/wasserstein/diagram_reader.h b/src/dionysus/wasserstein/diagram_reader.h new file mode 100755 index 0000000..8d09c9b --- /dev/null +++ b/src/dionysus/wasserstein/diagram_reader.h @@ -0,0 +1,369 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + +*/ + +#ifndef HERA_DIAGRAM_READER_H +#define HERA_DIAGRAM_READER_H + +#ifndef FOR_R_TDA +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef WASSERSTEIN_PURE_GEOM +#include "dnn/geometry/euclidean-dynamic.h" +#endif + +namespace hera { + +// cannot choose stod, stof or stold based on RealType, +// lazy solution: partial specialization +template +RealType parse_real_from_str(const std::string& s); + +template <> +double parse_real_from_str(const std::string& s) +{ + return std::stod(s); +} + + +template <> +long double parse_real_from_str(const std::string& s) +{ + return std::stold(s); +} + + +template <> +float parse_real_from_str(const std::string& s) +{ + return std::stof(s); +} + + +template +RealType parse_real_from_str(const std::string& s) +{ + static_assert(sizeof(RealType) != sizeof(RealType), "Must be specialized for each type you want to use, see above"); +} + +// fill in result with points from file fname +// return false if file can't be opened +// or error occurred while reading +// decPrecision is the maximal decimal precision in the input, +// it is zero if all coordinates in the input are integers +template>> +bool read_diagram_point_set(const char* fname, ContType_& result, int& decPrecision) +{ + size_t lineNumber { 0 }; + result.clear(); + std::ifstream f(fname); + if (!f.good()) { +#ifndef FOR_R_TDA + std::cerr << "Cannot open file " << fname << std::endl; +#endif + return false; + } + std::locale loc; + std::string line; + while(std::getline(f, line)) { + lineNumber++; + // process comments: remove everything after hash + auto hashPos = line.find_first_of("#", 0); + if( std::string::npos != hashPos) { + line = std::string(line.begin(), line.begin() + hashPos); + } + if (line.empty()) { + continue; + } + // trim whitespaces + auto whiteSpaceFront = std::find_if_not(line.begin(),line.end(),isspace); + auto whiteSpaceBack = std::find_if_not(line.rbegin(),line.rend(),isspace).base(); + if (whiteSpaceBack <= whiteSpaceFront) { + // line consists of spaces only - move to the next line + continue; + } + line = std::string(whiteSpaceFront,whiteSpaceBack); + + // transform line to lower case + // to parse Infinity + for(auto& c : line) { + c = std::tolower(c, loc); + } + + bool fracPart = false; + int currDecPrecision = 0; + for(auto c : line) { + if (c == '.') { + fracPart = true; + } else if (fracPart) { + if (isdigit(c)) { + currDecPrecision++; + } else { + fracPart = false; + if (currDecPrecision > decPrecision) + decPrecision = currDecPrecision; + currDecPrecision = 0; + } + } + } + + RealType x, y; + std::string str_x, str_y; + std::istringstream iss(line); + try { + iss >> str_x >> str_y; + + x = parse_real_from_str(str_x); + y = parse_real_from_str(str_y); + + result.push_back(std::make_pair(x, y)); + } + catch (const std::invalid_argument& e) { +#ifndef FOR_R_TDA + std::cerr << "Error in file " << fname << ", line number " << lineNumber << ": cannot parse \"" << line << "\"" << std::endl; +#endif + return false; + } + catch (const std::out_of_range&) { +#ifndef FOR_R_TDA + std::cerr << "Error while reading file " << fname << ", line number " << lineNumber << ": value too large in \"" << line << "\"" << std::endl; +#endif + return false; + } + } + f.close(); + return true; +} + + +// wrappers +template>> +bool read_diagram_point_set(const std::string& fname, ContType_& result, int& decPrecision) +{ + return read_diagram_point_set(fname.c_str(), result, decPrecision); +} + +// these two functions are now just wrappers for the previous ones, +// in case someone needs them; decPrecision is ignored +template>> +bool read_diagram_point_set(const char* fname, ContType_& result) +{ + int decPrecision; + return read_diagram_point_set(fname, result, decPrecision); +} + +template>> +bool read_diagram_point_set(const std::string& fname, ContType_& result) +{ + int decPrecision; + return read_diagram_point_set(fname.c_str(), result, decPrecision); +} + + +template +void remove_duplicates(ContType& dgm_A, ContType& dgm_B) +{ + std::map, int> map_A, map_B; + // copy points to maps + for(const auto& ptA : dgm_A) { + map_A[ptA]++; + } + for(const auto& ptB : dgm_B) { + map_B[ptB]++; + } + // clear vectors + dgm_A.clear(); + dgm_B.clear(); + // remove duplicates from maps + // loop over the smaller one + if (map_A.size() <= map_B.size()) { + for(auto& point_multiplicity_pair : map_A) { + auto iter_B = map_B.find(point_multiplicity_pair.first); + if (iter_B != map_B.end()) { + int duplicate_multiplicity = std::min(point_multiplicity_pair.second, iter_B->second); + point_multiplicity_pair.second -= duplicate_multiplicity; + iter_B->second -= duplicate_multiplicity; + } + } + } else { + for(auto& point_multiplicity_pair : map_B) { + auto iter_A = map_A.find(point_multiplicity_pair.first); + if (iter_A != map_A.end()) { + int duplicate_multiplicity = std::min(point_multiplicity_pair.second, iter_A->second); + point_multiplicity_pair.second -= duplicate_multiplicity; + iter_A->second -= duplicate_multiplicity; + } + } + } + // copy points back to vectors + for(const auto& pointMultiplicityPairA : map_A) { + assert( pointMultiplicityPairA.second >= 0); + for(int i = 0; i < pointMultiplicityPairA.second; ++i) { + dgm_A.push_back(pointMultiplicityPairA.first); + } + } + + for(const auto& pointMultiplicityPairB : map_B) { + assert( pointMultiplicityPairB.second >= 0); + for(int i = 0; i < pointMultiplicityPairB.second; ++i) { + dgm_B.push_back(pointMultiplicityPairB.first); + } + } +} + + +#ifdef WASSERSTEIN_PURE_GEOM + +template +int get_point_dimension(const std::string& line) +{ + Real x; + int dim = 0; + std::istringstream iss(line); + while(iss >> x) { + dim++; + } + return dim; +} + + +template +bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector& result, int& dimension, int& decPrecision) +{ + using DynamicPointTraitsR = typename hera::ws::dnn::DynamicPointTraits; + + size_t lineNumber { 0 }; + result.clear(); + std::ifstream f(fname); + if (!f.good()) { +#ifndef FOR_R_TDA + std::cerr << "Cannot open file " << fname << std::endl; +#endif + return false; + } + std::string line; + DynamicPointTraitsR traits; + bool dim_computed = false; + int point_idx = 0; + while(std::getline(f, line)) { + lineNumber++; + // process comments: remove everything after hash + auto hashPos = line.find_first_of("#", 0); + if( std::string::npos != hashPos) { + line = std::string(line.begin(), line.begin() + hashPos); + } + if (line.empty()) { + continue; + } + // trim whitespaces + auto whiteSpaceFront = std::find_if_not(line.begin(),line.end(),isspace); + auto whiteSpaceBack = std::find_if_not(line.rbegin(),line.rend(),isspace).base(); + if (whiteSpaceBack <= whiteSpaceFront) { + // line consists of spaces only - move to the next line + continue; + } + + line = std::string(whiteSpaceFront,whiteSpaceBack); + + if (not dim_computed) { + dimension = get_point_dimension(line); + traits = hera::ws::dnn::DynamicPointTraits(dimension); + result = traits.container(); + result.clear(); + dim_computed = true; + } + + bool fracPart = false; + int currDecPrecision = 0; + for(auto c : line) { + if (c == '.') { + fracPart = true; + } else if (fracPart) { + if (isdigit(c)) { + currDecPrecision++; + } else { + fracPart = false; + if (currDecPrecision > decPrecision) + decPrecision = currDecPrecision; + currDecPrecision = 0; + } + } + } + + result.resize(result.size() + 1); + RealType x; + std::istringstream iss(line); + for(int d = 0; d < dimension; ++d) { + if (not(iss >> x)) { +#ifndef FOR_R_TDA + std::cerr << "Error in file " << fname << ", line number " << lineNumber << ": cannot parse \"" << line << "\"" << std::endl; +#endif + return false; + } + result[point_idx][d] = x; + } + point_idx++; + } + f.close(); + return true; +} + +// wrappers +template +bool read_point_cloud(const char* fname, hera::ws::dnn::DynamicPointVector& result, int& dimension) +{ + int dec_precision; + return read_point_cloud(fname, result, dimension, dec_precision); +} + +template +bool read_point_cloud(std::string fname, hera::ws::dnn::DynamicPointVector& result, int& dimension, int& dec_precision) +{ + return read_point_cloud(fname.c_str(), result, dimension, dec_precision); +} + +template +bool read_point_cloud(std::string fname, hera::ws::dnn::DynamicPointVector& result, int& dimension) +{ + return read_point_cloud(fname.c_str(), result, dimension); +} + +#endif // WASSERSTEIN_PURE_GEOM + +} // end namespace hera +#endif // HERA_DIAGRAM_READER_H diff --git a/src/dionysus/wasserstein/dnn/geometry/euclidean-dynamic.h b/src/dionysus/wasserstein/dnn/geometry/euclidean-dynamic.h new file mode 100755 index 0000000..4b98309 --- /dev/null +++ b/src/dionysus/wasserstein/dnn/geometry/euclidean-dynamic.h @@ -0,0 +1,248 @@ +#ifndef DNN_GEOMETRY_EUCLIDEAN_DYNAMIC_H +#define DNN_GEOMETRY_EUCLIDEAN_DYNAMIC_H + +#include +#include +#include +#include +#include +#include + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + +template +class DynamicPointVector +{ + public: + using Real = Real_; + struct PointType + { + void* p; + + Real& operator[](const int i) + { + return (static_cast(p))[i]; + } + + const Real& operator[](const int i) const + { + return (static_cast(p))[i]; + } + + }; + struct iterator; + typedef iterator const_iterator; + + public: + DynamicPointVector(size_t point_capacity = 0): + point_capacity_(point_capacity) {} + + + PointType operator[](size_t i) const { return {(void*) &storage_[i*point_capacity_]}; } + inline void push_back(PointType p); + + inline iterator begin(); + inline iterator end(); + inline const_iterator begin() const; + inline const_iterator end() const; + + size_t size() const { return storage_.size() / point_capacity_; } + + void clear() { storage_.clear(); } + void swap(DynamicPointVector& other) { storage_.swap(other.storage_); std::swap(point_capacity_, other.point_capacity_); } + void reserve(size_t sz) { storage_.reserve(sz * point_capacity_); } + void resize(size_t sz) { storage_.resize(sz * point_capacity_); } + + private: + size_t point_capacity_; + std::vector storage_; + + private: + friend class boost::serialization::access; + + template + void serialize(Archive& ar, const unsigned int version) { ar & point_capacity_ & storage_; } +}; + +template +struct DynamicPointTraits +{ + typedef DynamicPointVector PointContainer; + typedef typename PointContainer::PointType PointType; + struct PointHandle + { + void* p; + bool operator==(const PointHandle& other) const { return p == other.p; } + bool operator!=(const PointHandle& other) const { return !(*this == other); } + bool operator<(const PointHandle& other) const { return p < other.p; } + bool operator>(const PointHandle& other) const { return p > other.p; } + }; + + typedef Real Coordinate; + typedef Real DistanceType; + + DynamicPointTraits(unsigned dim = 0): + dim_(dim) {} + + DistanceType distance(PointType p1, PointType p2) const { return sqrt(sq_distance(p1,p2)); } + DistanceType distance(PointHandle p1, PointHandle p2) const { return distance(PointType({p1.p}), PointType({p2.p})); } + DistanceType sq_distance(PointType p1, PointType p2) const { Real res = 0; for (unsigned i = 0; i < dimension(); ++i) { Real c1 = coordinate(p1,i), c2 = coordinate(p2,i); res += (c1 - c2)*(c1 - c2); } return res; } + DistanceType sq_distance(PointHandle p1, PointHandle p2) const { return sq_distance(PointType({p1.p}), PointType({p2.p})); } + unsigned dimension() const { return dim_; } + Real& coordinate(PointType p, unsigned i) const { return ((Real*) p.p)[i]; } + Real& coordinate(PointHandle h, unsigned i) const { return ((Real*) h.p)[i]; } + + // it's non-standard to return a reference, but we can rely on it for code that assumes this particular point type + size_t& id(PointType p) const { return *((size_t*) ((Real*) p.p + dimension())); } + size_t& id(PointHandle h) const { return *((size_t*) ((Real*) h.p + dimension())); } + PointHandle handle(PointType p) const { return {p.p}; } + PointType point(PointHandle h) const { return {h.p}; } + + void swap(PointType p1, PointType p2) const { std::swap_ranges((char*) p1.p, ((char*) p1.p) + capacity(), (char*) p2.p); } + bool cmp(PointType p1, PointType p2) const { return std::lexicographical_compare((Real*) p1.p, ((Real*) p1.p) + dimension(), (Real*) p2.p, ((Real*) p2.p) + dimension()); } + bool eq(PointType p1, PointType p2) const { return std::equal((Real*) p1.p, ((Real*) p1.p) + dimension(), (Real*) p2.p); } + + // non-standard, and possibly a weird name + size_t capacity() const { return sizeof(Real)*dimension() + sizeof(size_t); } + + PointContainer container(size_t n = 0) const { PointContainer c(capacity()); c.resize(n); return c; } + PointContainer container(size_t n, const PointType& p) const; + + typename PointContainer::iterator + iterator(PointContainer& c, PointHandle ph) const; + typename PointContainer::const_iterator + iterator(const PointContainer& c, PointHandle ph) const; + + Real internal_p; + + private: + unsigned dim_; + + private: + friend class boost::serialization::access; + + template + void serialize(Archive& ar, const unsigned int version) { ar & dim_; } +}; + +} // dnn + +template +struct dnn::DynamicPointVector::iterator: + public boost::iterator_facade +{ + typedef boost::iterator_facade Parent; + + + public: + typedef typename Parent::value_type value_type; + typedef typename Parent::difference_type difference_type; + typedef typename Parent::reference reference; + + iterator(size_t point_capacity = 0): + point_capacity_(point_capacity) {} + + iterator(void* p, size_t point_capacity): + p_(p), point_capacity_(point_capacity) {} + + private: + void increment() { p_ = ((char*) p_) + point_capacity_; } + void decrement() { p_ = ((char*) p_) - point_capacity_; } + void advance(difference_type n) { p_ = ((char*) p_) + n*point_capacity_; } + difference_type + distance_to(iterator other) const { return (((char*) other.p_) - ((char*) p_))/(int) point_capacity_; } + bool equal(const iterator& other) const { return p_ == other.p_; } + reference dereference() const { return {p_}; } + + friend class ::boost::iterator_core_access; + + private: + void* p_; + size_t point_capacity_; +}; + +template +void dnn::DynamicPointVector::push_back(PointType p) +{ + if (storage_.capacity() < storage_.size() + point_capacity_) + storage_.reserve(1.5*storage_.capacity()); + + storage_.resize(storage_.size() + point_capacity_); + + std::copy((char*) p.p, (char*) p.p + point_capacity_, storage_.end() - point_capacity_); +} + +template +typename dnn::DynamicPointVector::iterator dnn::DynamicPointVector::begin() { return iterator((void*) &*storage_.begin(), point_capacity_); } + +template +typename dnn::DynamicPointVector::iterator dnn::DynamicPointVector::end() { return iterator((void*) &*storage_.end(), point_capacity_); } + +template +typename dnn::DynamicPointVector::const_iterator dnn::DynamicPointVector::begin() const { return const_iterator((void*) &*storage_.begin(), point_capacity_); } + +template +typename dnn::DynamicPointVector::const_iterator dnn::DynamicPointVector::end() const { return const_iterator((void*) &*storage_.end(), point_capacity_); } + +template +typename dnn::DynamicPointTraits::PointContainer +dnn::DynamicPointTraits::container(size_t n, const PointType& p) const +{ + PointContainer c = container(n); + for (auto x : c) + std::copy((char*) p.p, (char*) p.p + capacity(), (char*) x.p); + return c; +} + +template +typename dnn::DynamicPointTraits::PointContainer::iterator +dnn::DynamicPointTraits::iterator(PointContainer& c, PointHandle ph) const +{ return typename PointContainer::iterator(ph.p, capacity()); } + +template +typename dnn::DynamicPointTraits::PointContainer::const_iterator +dnn::DynamicPointTraits::iterator(const PointContainer& c, PointHandle ph) const +{ return typename PointContainer::const_iterator(ph.p, capacity()); } + +} // ws +} // hera + +namespace std { + template<> + struct hash::PointHandle> + { + using PointHandle = typename hera::ws::dnn::DynamicPointTraits::PointHandle; + size_t operator()(const PointHandle& ph) const + { + return std::hash()(ph.p); + } + }; + + template<> + struct hash::PointHandle> + { + using PointHandle = typename hera::ws::dnn::DynamicPointTraits::PointHandle; + size_t operator()(const PointHandle& ph) const + { + return std::hash()(ph.p); + } + }; + + +} // std + + +#endif diff --git a/src/dionysus/wasserstein/dnn/geometry/euclidean-fixed.h b/src/dionysus/wasserstein/dnn/geometry/euclidean-fixed.h new file mode 100755 index 0000000..3e38baf --- /dev/null +++ b/src/dionysus/wasserstein/dnn/geometry/euclidean-fixed.h @@ -0,0 +1,196 @@ +#ifndef HERA_WS_DNN_GEOMETRY_EUCLIDEAN_FIXED_H +#define HERA_WS_DNN_GEOMETRY_EUCLIDEAN_FIXED_H + +#include +#include +#include +#include +#include + +//#include +#include +#include +#include +#include + +#include "../parallel/tbb.h" // for dnn::vector<...> + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + // TODO: wrap in another namespace (e.g., euclidean) + + template + struct Point: + boost::addable< Point, + boost::subtractable< Point, + boost::dividable2< Point, Real, + boost::multipliable2< Point, Real > > > >, + public boost::array + { + public: + typedef Real Coordinate; + typedef Real DistanceType; + + + public: + Point(size_t id = 0): id_(id) {} + template + Point(const Point& p, size_t id = 0): + id_(id) { *this = p; } + + static size_t dimension() { return D; } + + // Assign a point of different dimension + template + Point& operator=(const Point& p) { for (size_t i = 0; i < (D < DD ? D : DD); ++i) (*this)[i] = p[i]; if (DD < D) for (size_t i = DD; i < D; ++i) (*this)[i] = 0; return *this; } + + Point& operator+=(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] += p[i]; return *this; } + Point& operator-=(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] -= p[i]; return *this; } + Point& operator/=(Real r) { for (size_t i = 0; i < D; ++i) (*this)[i] /= r; return *this; } + Point& operator*=(Real r) { for (size_t i = 0; i < D; ++i) (*this)[i] *= r; return *this; } + + Real norm2() const { Real n = 0; for (size_t i = 0; i < D; ++i) n += (*this)[i] * (*this)[i]; return n; } + Real max_norm() const + { + Real res = std::fabs((*this)[0]); + for (size_t i = 1; i < D; ++i) + if (std::fabs((*this)[i]) > res) + res = std::fabs((*this)[i]); + return res; + } + + Real l1_norm() const + { + Real res = std::fabs((*this)[0]); + for (size_t i = 1; i < D; ++i) + res += std::fabs((*this)[i]); + return res; + } + + Real lp_norm(const Real p) const + { + assert( !std::isinf(p) ); + if ( p == 1.0 ) + return l1_norm(); + Real res = std::pow(std::fabs((*this)[0]), p); + for (size_t i = 1; i < D; ++i) + res += std::pow(std::fabs((*this)[i]), p); + return std::pow(res, 1.0 / p); + } + + // quick and dirty for now; make generic later + //DistanceType distance(const Point& other) const { return sqrt(sq_distance(other)); } + //DistanceType sq_distance(const Point& other) const { return (other - *this).norm2(); } + + DistanceType distance(const Point& other) const { return (other - *this).max_norm(); } + DistanceType p_distance(const Point& other, const double p) const { return (other - *this).lp_norm(p); } + + size_t id() const { return id_; } + size_t& id() { return id_; } + + private: + friend class boost::serialization::access; + + template + void serialize(Archive& ar, const unsigned int version) { ar & boost::serialization::base_object< boost::array >(*this) & id_; } + + private: + size_t id_; + }; + + template + std::ostream& + operator<<(std::ostream& out, const Point& p) + { out << p[0]; for (size_t i = 1; i < D; ++i) out << " " << p[i]; return out; } + + + template + struct PointTraits; // intentionally undefined; should be specialized for each type + + + template + struct PointTraits< Point > // specialization for dnn::Point + { + typedef Point PointType; + typedef const PointType* PointHandle; + typedef std::vector PointContainer; + + typedef typename PointType::Coordinate Coordinate; + typedef typename PointType::DistanceType DistanceType; + + + static DistanceType + distance(const PointType& p1, const PointType& p2) { if (hera::is_infinity(internal_p)) return p1.distance(p2); else return p1.p_distance(p2, internal_p); } + + static DistanceType + distance(PointHandle p1, PointHandle p2) { return distance(*p1,*p2); } + + static size_t dimension() { return D; } + static Real coordinate(const PointType& p, size_t i) { return p[i]; } + static Real& coordinate(PointType& p, size_t i) { return p[i]; } + static Real coordinate(PointHandle p, size_t i) { return coordinate(*p,i); } + + static size_t id(const PointType& p) { return p.id(); } + static size_t& id(PointType& p) { return p.id(); } + static size_t id(PointHandle p) { return id(*p); } + + static PointHandle + handle(const PointType& p) { return &p; } + static const PointType& + point(PointHandle ph) { return *ph; } + + void swap(PointType& p1, PointType& p2) const { return std::swap(p1, p2); } + + static PointContainer + container(size_t n = 0, const PointType& p = PointType()) { return PointContainer(n, p); } + static typename PointContainer::iterator + iterator(PointContainer& c, PointHandle ph) { return c.begin() + (ph - &c[0]); } + static typename PointContainer::const_iterator + iterator(const PointContainer& c, PointHandle ph) { return c.begin() + (ph - &c[0]); } + + // Internal_p determines which norm will be used in Wasserstein metric (not to + // be confused with wassersteinPower parameter: + // we raise \| p - q \|_{internal_p} to wassersteinPower. + static Real internal_p; + + private: + + friend class boost::serialization::access; + + template + void serialize(Archive& ar, const unsigned int version) {} + + }; + + template + Real PointTraits< Point >::internal_p = hera::get_infinity(); + + + template + void read_points(const std::string& filename, PointContainer& points) + { + typedef typename boost::range_value::type Point; + typedef typename PointTraits::Coordinate Coordinate; + + std::ifstream in(filename.c_str()); + std::string line; + while(std::getline(in, line)) + { + if (line[0] == '#') continue; // comment line in the file + std::stringstream linestream(line); + Coordinate x; + points.push_back(Point()); + size_t i = 0; + while (linestream >> x) + points.back()[i++] = x; + } + } +} // dnn +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/dnn/local/kd-tree.h b/src/dionysus/wasserstein/dnn/local/kd-tree.h new file mode 100755 index 0000000..8e52a5c --- /dev/null +++ b/src/dionysus/wasserstein/dnn/local/kd-tree.h @@ -0,0 +1,97 @@ +#ifndef HERA_WS_DNN_LOCAL_KD_TREE_H +#define HERA_WS_DNN_LOCAL_KD_TREE_H + +#include "../utils.h" +#include "search-functors.h" + +#include + +#include +#include +#include + +#include +#include + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + // Weighted KDTree + // Traits_ provides Coordinate, DistanceType, PointType, dimension(), distance(p1,p2), coordinate(p,i) + template< class Traits_ > + class KDTree + { + public: + typedef Traits_ Traits; + typedef dnn::HandleDistance HandleDistance; + + typedef typename Traits::PointType Point; + typedef typename Traits::PointHandle PointHandle; + typedef typename Traits::Coordinate Coordinate; + typedef typename Traits::DistanceType DistanceType; + typedef std::vector HandleContainer; + typedef std::vector HDContainer; // TODO: use tbb::scalable_allocator + typedef HDContainer Result; + typedef std::vector DistanceContainer; + typedef std::unordered_map HandleMap; + + BOOST_STATIC_ASSERT_MSG(has_coordinates::value, "KDTree requires coordinates"); + + public: + KDTree(const Traits& traits): + traits_(traits) {} + + KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower = 1.0); + + template + KDTree(const Traits& traits, const Range& range, double _wassersteinPower = 1.0); + + template + void init(const Range& range); + + DistanceType weight(PointHandle p) { return weights_[indices_[p]]; } + void change_weight(PointHandle p, DistanceType w); + void adjust_weights(DistanceType delta); // subtract delta from all weights + + HandleDistance find(PointHandle q) const; + Result findR(PointHandle q, DistanceType r) const; // all neighbors within r + Result findK(PointHandle q, size_t k) const; // k nearest neighbors + + HandleDistance find(const Point& q) const { return find(traits().handle(q)); } + Result findR(const Point& q, DistanceType r) const { return findR(traits().handle(q), r); } + Result findK(const Point& q, size_t k) const { return findK(traits().handle(q), k); } + + template + void search(PointHandle q, ResultsFunctor& rf) const; + + const Traits& traits() const { return traits_; } + + void printWeights(void); + + private: + void init(); + + typedef typename HandleContainer::iterator HCIterator; + typedef std::tuple KDTreeNode; + + struct CoordinateComparison; + struct OrderTree; + + private: + Traits traits_; + HandleContainer tree_; + DistanceContainer weights_; // point weight + DistanceContainer subtree_weights_; // min weight in the subtree + HandleMap indices_; + double wassersteinPower; + }; +} // dnn +} // ws +} // hera + +#include "kd-tree.hpp" + +#endif diff --git a/src/dionysus/wasserstein/dnn/local/kd-tree.hpp b/src/dionysus/wasserstein/dnn/local/kd-tree.hpp new file mode 100755 index 0000000..4699ca3 --- /dev/null +++ b/src/dionysus/wasserstein/dnn/local/kd-tree.hpp @@ -0,0 +1,330 @@ +#include +#include +#include + +#include +#include + +#include "../parallel/tbb.h" +#include "../../def_debug_ws.h" + +template +hera::ws::dnn::KDTree:: +KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower): + traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower) +{ assert(wassersteinPower >= 1.0); init(); } + +template +template +hera::ws::dnn::KDTree:: +KDTree(const Traits& traits, const Range& range, double _wassersteinPower): + traits_(traits), wassersteinPower(_wassersteinPower) +{ + assert( wassersteinPower >= 1.0); + init(range); +} + +template +template +void +hera::ws::dnn::KDTree:: +init(const Range& range) +{ + size_t sz = std::distance(std::begin(range), std::end(range)); + tree_.reserve(sz); + weights_.resize(sz, 0); + subtree_weights_.resize(sz, 0); + for (PointHandle h : range) + tree_.push_back(h); + init(); +} + +template +void +hera::ws::dnn::KDTree:: +init() +{ + if (tree_.empty()) + return; + +#if defined(TBB) + task_group g; + g.run(OrderTree(tree_.begin(), tree_.end(), 0, traits())); + g.wait(); +#else + OrderTree(tree_.begin(), tree_.end(), 0, traits()).serial(); +#endif + + for (size_t i = 0; i < tree_.size(); ++i) + indices_[tree_[i]] = i; +} + +template +struct +hera::ws::dnn::KDTree::OrderTree +{ + OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_): + b(b_), e(e_), i(i_), traits(traits_) {} + + void operator()() const + { + if (e - b < 1000) + { + serial(); + return; + } + + HCIterator m = b + (e - b)/2; + CoordinateComparison cmp(i, traits); + std::nth_element(b,m,e, cmp); + size_t next_i = (i + 1) % traits.dimension(); + + task_group g; + if (b < m - 1) g.run(OrderTree(b, m, next_i, traits)); + if (e > m + 2) g.run(OrderTree(m+1, e, next_i, traits)); + g.wait(); + } + + void serial() const + { + std::queue q; + q.push(KDTreeNode(b,e,i)); + while (!q.empty()) + { + HCIterator b, e; size_t i; + std::tie(b,e,i) = q.front(); + q.pop(); + HCIterator m = b + (e - b)/2; + + CoordinateComparison cmp(i, traits); + std::nth_element(b,m,e, cmp); + size_t next_i = (i + 1) % traits.dimension(); + + // Replace with a size condition instead? + if (b < m - 1) q.push(KDTreeNode(b, m, next_i)); + if (e - m > 2) q.push(KDTreeNode(m+1, e, next_i)); + } + } + + HCIterator b, e; + size_t i; + const Traits& traits; +}; + +template +template +void +hera::ws::dnn::KDTree:: +search(PointHandle q, ResultsFunctor& rf) const +{ + typedef typename HandleContainer::const_iterator HCIterator; + typedef std::tuple KDTreeNode; + + if (tree_.empty()) + return; + + DistanceType D = std::numeric_limits::max(); + + // TODO: use tbb::scalable_allocator for the queue + std::queue nodes; + + nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0)); + + while (!nodes.empty()) + { + HCIterator b, e; size_t i; + std::tie(b,e,i) = nodes.front(); + nodes.pop(); + + CoordinateComparison cmp(i, traits()); + i = (i + 1) % traits().dimension(); + + HCIterator m = b + (e - b)/2; + + DistanceType dist = (wassersteinPower == 1.0) ? traits().distance(q, *m) + weights_[m - tree_.begin()] : std::pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()]; + + + D = rf(*m, dist); + + // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball + Coordinate diff = cmp.diff(q, *m); // diff returns signed distance + + DistanceType diffToWasserPower = (wassersteinPower == 1.0) ? diff : ((diff > 0 ? 1.0 : -1.0) * std::pow(fabs(diff), wassersteinPower)); + + size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin(); + if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) { + nodes.push(KDTreeNode(m+1, e, i)); + } + + size_t rm = b + (m - b) / 2 - tree_.begin(); + if (b < m && diffToWasserPower + subtree_weights_[rm] <= D) { + nodes.push(KDTreeNode(b, m, i)); + } + } +} + +template +void +hera::ws::dnn::KDTree:: +adjust_weights(DistanceType delta) +{ + for(auto& w : weights_) + w -= delta; + + for(auto& sw : subtree_weights_) + sw -= delta; +} + + +template +void +hera::ws::dnn::KDTree:: +change_weight(PointHandle p, DistanceType w) +{ + size_t idx = indices_[p]; + + if ( weights_[idx] == w ) { + return; + } + + bool weight_increases = ( weights_[idx] < w ); + weights_[idx] = w; + + typedef std::tuple KDTreeNode; + + // find the path down the tree to this node + // not an ideal strategy, but // it's not clear how to move up from the node in general + std::stack s; + s.push(KDTreeNode(tree_.begin(),tree_.end())); + + do + { + HCIterator b,e; + std::tie(b,e) = s.top(); + + size_t im = b + (e - b)/2 - tree_.begin(); + + if (idx == im) + break; + else if (idx < im) + s.push(KDTreeNode(b, tree_.begin() + im)); + else // idx > im + s.push(KDTreeNode(tree_.begin() + im + 1, e)); + } while(1); + + // update subtree_weights_ on the path to the root + DistanceType min_w = w; + while (!s.empty()) + { + HCIterator b,e; + std::tie(b,e) = s.top(); + HCIterator m = b + (e - b)/2; + size_t im = m - tree_.begin(); + s.pop(); + + + // left and right children + if (b < m) + { + size_t lm = b + (m - b)/2 - tree_.begin(); + if (subtree_weights_[lm] < min_w) + min_w = subtree_weights_[lm]; + } + + if (e > m + 1) + { + size_t rm = m + 1 + (e - (m+1))/2 - tree_.begin(); + if (subtree_weights_[rm] < min_w) + min_w = subtree_weights_[rm]; + } + + if (weights_[im] < min_w) { + min_w = weights_[im]; + } + + if (weight_increases) { + + if (subtree_weights_[im] < min_w ) // increase weight + subtree_weights_[im] = min_w; + else + break; + + } else { + + if (subtree_weights_[im] > min_w ) // decrease weight + subtree_weights_[im] = min_w; + else + break; + + } + } +} + +template +typename hera::ws::dnn::KDTree::HandleDistance +hera::ws::dnn::KDTree:: +find(PointHandle q) const +{ + hera::ws::dnn::NNRecord nn; + search(q, nn); + return nn.result; +} + +template +typename hera::ws::dnn::KDTree::Result +hera::ws::dnn::KDTree:: +findR(PointHandle q, DistanceType r) const +{ + hera::ws::dnn::rNNRecord rnn(r); + search(q, rnn); + std::sort(rnn.result.begin(), rnn.result.end()); + return rnn.result; +} + +template +typename hera::ws::dnn::KDTree::Result +hera::ws::dnn::KDTree:: +findK(PointHandle q, size_t k) const +{ + hera::ws::dnn::kNNRecord knn(k); + search(q, knn); + std::sort(knn.result.begin(), knn.result.end()); + return knn.result; +} + + +template +struct hera::ws::dnn::KDTree::CoordinateComparison +{ + CoordinateComparison(size_t i, const Traits& traits): + i_(i), traits_(traits) {} + + bool operator()(PointHandle p1, PointHandle p2) const { return coordinate(p1) < coordinate(p2); } + Coordinate diff(PointHandle p1, PointHandle p2) const { return coordinate(p1) - coordinate(p2); } + + Coordinate coordinate(PointHandle p) const { return traits_.coordinate(p, i_); } + size_t axis() const { return i_; } + + private: + size_t i_; + const Traits& traits_; +}; + +template +void +hera::ws::dnn::KDTree:: +printWeights(void) +{ +#ifndef FOR_R_TDA + std::cout << "weights_:" << std::endl; + for(const auto ph : indices_) { + std::cout << "idx = " << ph.second << ": (" << (ph.first)->at(0) << ", " << (ph.first)->at(1) << ") weight = " << weights_[ph.second] << std::endl; + } + std::cout << "subtree_weights_:" << std::endl; + for(size_t idx = 0; idx < subtree_weights_.size(); ++idx) { + std::cout << idx << " : " << subtree_weights_[idx] << std::endl; + } +#endif +} + + diff --git a/src/dionysus/wasserstein/dnn/local/search-functors.h b/src/dionysus/wasserstein/dnn/local/search-functors.h new file mode 100755 index 0000000..1419f22 --- /dev/null +++ b/src/dionysus/wasserstein/dnn/local/search-functors.h @@ -0,0 +1,95 @@ +#ifndef HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H +#define HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H + +#include + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + +template +struct HandleDistance +{ + typedef typename NN::PointHandle PointHandle; + typedef typename NN::DistanceType DistanceType; + typedef typename NN::HDContainer HDContainer; + + HandleDistance() {} + HandleDistance(PointHandle pp, DistanceType dd): + p(pp), d(dd) {} + bool operator<(const HandleDistance& other) const { return d < other.d; } + + PointHandle p; + DistanceType d; +}; + +template +struct NNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + + NNRecord() { result.d = std::numeric_limits::max(); } + DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; } + HandleDistance result; +}; + +template +struct rNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + rNNRecord(DistanceType r_): r(r_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (d <= r) + result.push_back(HandleDistance(p,d)); + return r; + } + + DistanceType r; + HDContainer result; +}; + +template +struct kNNRecord +{ + typedef typename HandleDistance::PointHandle PointHandle; + typedef typename HandleDistance::DistanceType DistanceType; + typedef typename HandleDistance::HDContainer HDContainer; + + kNNRecord(unsigned k_): k(k_) {} + DistanceType operator()(PointHandle p, DistanceType d) + { + if (result.size() < k) + { + result.push_back(HandleDistance(p,d)); + boost::push_heap(result); + if (result.size() < k) + return std::numeric_limits::max(); + } else if (d < result[0].d) + { + boost::pop_heap(result); + result.back() = HandleDistance(p,d); + boost::push_heap(result); + } + if ( result.size() > 1 ) { + assert( result[0].d >= result[1].d ); + } + return result[0].d; + } + + unsigned k; + HDContainer result; +}; + +} // dnn +} // ws +} // hera + +#endif // DNN_LOCAL_SEARCH_FUNCTORS_H diff --git a/src/dionysus/wasserstein/dnn/parallel/tbb.h b/src/dionysus/wasserstein/dnn/parallel/tbb.h new file mode 100755 index 0000000..3f811d6 --- /dev/null +++ b/src/dionysus/wasserstein/dnn/parallel/tbb.h @@ -0,0 +1,237 @@ +#ifndef HERA_WS_PARALLEL_H +#define HERA_WS_PARALLEL_H + +#include + +#include +#include +#include + +#ifdef TBB + +#include +#include +#include + +#include +#include +#include + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + using tbb::mutex; + using tbb::task_scheduler_init; + using tbb::task_group; + using tbb::task; + + template + struct vector + { + typedef tbb::concurrent_vector type; + }; + + template + struct atomic + { + typedef tbb::atomic type; + static T compare_and_swap(type& v, T n, T o) { return v.compare_and_swap(n,o); } + }; + + template + void do_foreach(Iterator begin, Iterator end, const F& f) { tbb::parallel_do(begin, end, f); } + + template + void for_each_range_(const Range& r, const F& f) + { + for (typename Range::iterator cur = r.begin(); cur != r.end(); ++cur) + f(*cur); + } + + template + void for_each_range(size_t from, size_t to, const F& f) + { + //static tbb::affinity_partitioner ap; + //tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f), ap); + tbb::parallel_for(from, to, f); + } + + template + void for_each_range(const Container& c, const F& f) + { + //static tbb::affinity_partitioner ap; + //tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f), ap); + tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f)); + } + + template + void for_each_range(Container& c, const F& f) + { + //static tbb::affinity_partitioner ap; + //tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f), ap); + tbb::parallel_for(c.range(), boost::bind(&for_each_range_, _1, f)); + } + + template + struct map_traits + { + typedef tbb::concurrent_hash_map type; + typedef typename type::range_type range; + }; + + struct progress_timer + { + progress_timer(): start(tbb::tick_count::now()) {} + ~progress_timer() + { std::cout << (tbb::tick_count::now() - start).seconds() << " s" << std::endl; } + + tbb::tick_count start; + }; +} // dnn +} // ws +} // hera + +// Serialization for tbb::concurrent_vector<...> +namespace boost +{ + namespace serialization + { + template + void save(Archive& ar, const tbb::concurrent_vector& v, const unsigned int file_version) + { stl::save_collection(ar, v); } + + template + void load(Archive& ar, tbb::concurrent_vector& v, const unsigned int file_version) + { + stl::load_collection, + stl::archive_input_seq< Archive, tbb::concurrent_vector >, + stl::reserve_imp< tbb::concurrent_vector > + >(ar, v); + } + + template + void serialize(Archive& ar, tbb::concurrent_vector& v, const unsigned int file_version) + { split_free(ar, v, file_version); } + + template + void save(Archive& ar, const tbb::atomic& v, const unsigned int file_version) + { T v_ = v; ar << v_; } + + template + void load(Archive& ar, tbb::atomic& v, const unsigned int file_version) + { T v_; ar >> v_; v = v_; } + + template + void serialize(Archive& ar, tbb::atomic& v, const unsigned int file_version) + { split_free(ar, v, file_version); } + } +} + +#else + +#include +#include +#include + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + template + struct vector + { + typedef ::std::vector type; + }; + + template + struct atomic + { + typedef T type; + static T compare_and_swap(type& v, T n, T o) { if (v != o) return v; v = n; return o; } + }; + + template + void do_foreach(Iterator begin, Iterator end, const F& f) { std::for_each(begin, end, f); } + + template + void for_each_range(size_t from, size_t to, const F& f) + { + for (size_t i = from; i < to; ++i) + f(i); + } + + template + void for_each_range(Container& c, const F& f) + { + BOOST_FOREACH(const typename Container::value_type& i, c) + f(i); + } + + template + void for_each_range(const Container& c, const F& f) + { + BOOST_FOREACH(const typename Container::value_type& i, c) + f(i); + } + + struct mutex + { + struct scoped_lock + { + scoped_lock() {} + scoped_lock(mutex& ) {} + void acquire(mutex& ) const {} + void release() const {} + }; + }; + + struct task_scheduler_init + { + task_scheduler_init(unsigned) {} + void initialize(unsigned) {} + static const unsigned automatic = 0; + static const unsigned deferred = 0; + }; + + struct task_group + { + template + void run(const Functor& f) const { f(); } + void wait() const {} + }; + + template + struct map_traits + { + typedef std::map type; + typedef type range; + }; + + using boost::progress_timer; +} // dnn +} // ws +} // hera + +#endif // TBB + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + template + void do_foreach(const Range& range, const F& f) { do_foreach(boost::begin(range), boost::end(range), f); } +} // dnn +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/dnn/parallel/utils.h b/src/dionysus/wasserstein/dnn/parallel/utils.h new file mode 100755 index 0000000..7104ec3 --- /dev/null +++ b/src/dionysus/wasserstein/dnn/parallel/utils.h @@ -0,0 +1,100 @@ +#ifndef HERA_WS_PARALLEL_UTILS_H +#define HERA_WS_PARALLEL_UTILS_H + +#include "../utils.h" + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + // Assumes rng is synchronized across ranks + template + void shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty = DataVector()); + + template + void shuffle(mpi::communicator& world, DataVector& data, RNGType& rng) + { + typedef decltype(data[0]) T; + shuffle(world, data, rng, [](T& x, T& y) { std::swap(x,y); }); + } +} // dnn +} // ws +} // hera + +template +void +hera::ws::dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty) +{ + // This is not a perfect shuffle: it dishes out data in chunks of 1/size. + // (It can be interpreted as generating a bistochastic matrix by taking the + // sum of size random permutation matrices.) Hopefully, it works for our purposes. + + typedef typename RNGType::result_type RNGResult; + + int size = world.size(); + int rank = world.rank(); + + // Generate local seeds + boost::uniform_int uniform; + RNGResult seed; + for (size_t i = 0; i < size; ++i) + { + RNGResult v = uniform(rng); + if (i == rank) + seed = v; + } + RNGType local_rng(seed); + + // Shuffle local data + hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + + // Decide how much of our data goes to i-th processor + std::vector out_counts(size); + std::vector ranks(boost::counting_iterator(0), + boost::counting_iterator(size)); + for (size_t i = 0; i < size; ++i) + { + hera::ws::dnn::random_shuffle(ranks.begin(), ranks.end(), rng); + ++out_counts[ranks[rank]]; + } + + // Fill the outgoing array + size_t total = 0; + std::vector< DataVector > outgoing(size, empty); + for (size_t i = 0; i < size; ++i) + { + size_t count = data.size()*out_counts[i]/size; + if (total + count > data.size()) + count = data.size() - total; + + outgoing[i].reserve(count); + for (size_t j = total; j < total + count; ++j) + outgoing[i].push_back(data[j]); + + total += count; + } + + boost::uniform_int uniform_outgoing(0,size-1); // in range [0,size-1] + while(total < data.size()) // send leftover to random processes + { + outgoing[uniform_outgoing(local_rng)].push_back(data[total]); + ++total; + } + data.clear(); + + // Exchange the data + std::vector< DataVector > incoming(size, empty); + mpi::all_to_all(world, outgoing, incoming); + outgoing.clear(); + + // Assemble our data + for(const DataVector& vec : incoming) + for (size_t i = 0; i < vec.size(); ++i) + data.push_back(vec[i]); + hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + // XXX: the final shuffle is irrelevant for our purposes. But it's also cheap. +} + +#endif diff --git a/src/dionysus/wasserstein/dnn/utils.h b/src/dionysus/wasserstein/dnn/utils.h new file mode 100755 index 0000000..bbce793 --- /dev/null +++ b/src/dionysus/wasserstein/dnn/utils.h @@ -0,0 +1,47 @@ +#ifndef HERA_WS_DNN_UTILS_H +#define HERA_WS_DNN_UTILS_H + +#include +#include +#include + +namespace hera +{ +namespace ws +{ +namespace dnn +{ + +template +struct has_coordinates +{ + template ().coordinate(std::declval()...) )> + static std::true_type test(int); + + template + static std::false_type test(...); + + static constexpr bool value = decltype(test(0))::value; +}; + +template +void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& g, const SwapFunctor& swap) +{ + size_t n = last - first; + boost::uniform_int uniform(0,n); + for (size_t i = n-1; i > 0; --i) + swap(first[i], first[uniform(g,i+1)]); // picks a random number in [0,i] range +} + +template +void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& g) +{ + typedef decltype(*first) T; + random_shuffle(first, last, g, [](T& x, T& y) { std::swap(x,y); }); +} + +} // dnn +} // ws +} // hera + +#endif diff --git a/src/dionysus/wasserstein/spdlog/async_logger.h b/src/dionysus/wasserstein/spdlog/async_logger.h new file mode 100755 index 0000000..e9fcd5f --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/async_logger.h @@ -0,0 +1,82 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// Very fast asynchronous logger (millions of logs per second on an average desktop) +// Uses pre allocated lockfree queue for maximum throughput even under large number of threads. +// Creates a single back thread to pop messages from the queue and log them. +// +// Upon each log write the logger: +// 1. Checks if its log level is enough to log the message +// 2. Push a new copy of the message to a queue (or block the caller until space is available in the queue) +// 3. will throw spdlog_ex upon log exceptions +// Upon destruction, logs all remaining messages in the queue before destructing.. + +#include "common.h" +#include "logger.h" + +#include +#include +#include +#include + +namespace spdlog +{ + +namespace details +{ +class async_log_helper; +} + +class async_logger SPDLOG_FINAL :public logger +{ +public: + template + async_logger(const std::string& name, + const It& begin, + const It& end, + size_t queue_size, + const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, + const std::function& worker_warmup_cb = nullptr, + const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), + const std::function& worker_teardown_cb = nullptr); + + async_logger(const std::string& logger_name, + sinks_init_list sinks, + size_t queue_size, + const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, + const std::function& worker_warmup_cb = nullptr, + const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), + const std::function& worker_teardown_cb = nullptr); + + async_logger(const std::string& logger_name, + sink_ptr single_sink, + size_t queue_size, + const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, + const std::function& worker_warmup_cb = nullptr, + const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), + const std::function& worker_teardown_cb = nullptr); + + //Wait for the queue to be empty, and flush synchronously + //Warning: this can potentially last forever as we wait it to complete + void flush() override; + + // Error handler + virtual void set_error_handler(log_err_handler) override; + virtual log_err_handler error_handler() override; + +protected: + void _sink_it(details::log_msg& msg) override; + void _set_formatter(spdlog::formatter_ptr msg_formatter) override; + void _set_pattern(const std::string& pattern, pattern_time_type pattern_time) override; + +private: + std::unique_ptr _async_log_helper; +}; +} + + +#include "details/async_logger_impl.h" diff --git a/src/dionysus/wasserstein/spdlog/common.h b/src/dionysus/wasserstein/spdlog/common.h new file mode 100755 index 0000000..252a2d6 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/common.h @@ -0,0 +1,160 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) +#include +#include +#endif + +#include "details/null_mutex.h" + +//visual studio upto 2013 does not support noexcept nor constexpr +#if defined(_MSC_VER) && (_MSC_VER < 1900) +#define SPDLOG_NOEXCEPT throw() +#define SPDLOG_CONSTEXPR +#else +#define SPDLOG_NOEXCEPT noexcept +#define SPDLOG_CONSTEXPR constexpr +#endif + +// See tweakme.h +#if !defined(SPDLOG_FINAL) +#define SPDLOG_FINAL +#endif + +#if defined(__GNUC__) || defined(__clang__) +#define SPDLOG_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define SPDLOG_DEPRECATED __declspec(deprecated) +#else +#define SPDLOG_DEPRECATED +#endif + + +#include "fmt/fmt.h" + +namespace spdlog +{ + +class formatter; + +namespace sinks +{ +class sink; +} + +using log_clock = std::chrono::system_clock; +using sink_ptr = std::shared_ptr < sinks::sink >; +using sinks_init_list = std::initializer_list < sink_ptr >; +using formatter_ptr = std::shared_ptr; +#if defined(SPDLOG_NO_ATOMIC_LEVELS) +using level_t = details::null_atomic_int; +#else +using level_t = std::atomic; +#endif + +using log_err_handler = std::function; + +//Log level enum +namespace level +{ +typedef enum +{ + trace = 0, + debug = 1, + info = 2, + warn = 3, + err = 4, + critical = 5, + off = 6 +} level_enum; + +#if !defined(SPDLOG_LEVEL_NAMES) +#define SPDLOG_LEVEL_NAMES { "trace", "debug", "info", "warning", "error", "critical", "off" }; +#endif +static const char* level_names[] SPDLOG_LEVEL_NAMES + +static const char* short_level_names[] { "T", "D", "I", "W", "E", "C", "O" }; + +inline const char* to_str(spdlog::level::level_enum l) +{ + return level_names[l]; +} + +inline const char* to_short_str(spdlog::level::level_enum l) +{ + return short_level_names[l]; +} +} //level + + +// +// Async overflow policy - block by default. +// +enum class async_overflow_policy +{ + block_retry, // Block / yield / sleep until message can be enqueued + discard_log_msg // Discard the message it enqueue fails +}; + +// +// Pattern time - specific time getting to use for pattern_formatter. +// local time by default +// +enum class pattern_time_type +{ + local, // log localtime + utc // log utc +}; + +// +// Log exception +// +namespace details +{ +namespace os +{ +std::string errno_str(int err_num); +} +} +class spdlog_ex: public std::exception +{ +public: + spdlog_ex(const std::string& msg):_msg(msg) + {} + spdlog_ex(const std::string& msg, int last_errno) + { + _msg = msg + ": " + details::os::errno_str(last_errno); + } + const char* what() const SPDLOG_NOEXCEPT override + { + return _msg.c_str(); + } +private: + std::string _msg; + +}; + +// +// wchar support for windows file names (SPDLOG_WCHAR_FILENAMES must be defined) +// +#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) +using filename_t = std::wstring; +#else +using filename_t = std::string; +#endif + + +} //spdlog diff --git a/src/dionysus/wasserstein/spdlog/details/async_log_helper.h b/src/dionysus/wasserstein/spdlog/details/async_log_helper.h new file mode 100755 index 0000000..ceb1d69 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/async_log_helper.h @@ -0,0 +1,399 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +// async log helper : +// Process logs asynchronously using a back thread. +// +// If the internal queue of log messages reaches its max size, +// then the client call will block until there is more room. +// + +#pragma once + +#include "../common.h" +#include "../sinks/sink.h" +#include "../details/mpmc_bounded_q.h" +#include "../details/log_msg.h" +#include "../details/os.h" +#include "../formatter.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spdlog +{ +namespace details +{ + +class async_log_helper +{ + // Async msg to move to/from the queue + // Movable only. should never be copied + enum class async_msg_type + { + log, + flush, + terminate + }; + struct async_msg + { + std::string logger_name; + level::level_enum level; + log_clock::time_point time; + size_t thread_id; + std::string txt; + async_msg_type msg_type; + size_t msg_id; + + async_msg() = default; + ~async_msg() = default; + + +async_msg(async_msg&& other) SPDLOG_NOEXCEPT: + logger_name(std::move(other.logger_name)), + level(std::move(other.level)), + time(std::move(other.time)), + thread_id(other.thread_id), + txt(std::move(other.txt)), + msg_type(std::move(other.msg_type)), + msg_id(other.msg_id) + {} + + async_msg(async_msg_type m_type): + level(level::info), + thread_id(0), + msg_type(m_type), + msg_id(0) + {} + + async_msg& operator=(async_msg&& other) SPDLOG_NOEXCEPT + { + logger_name = std::move(other.logger_name); + level = other.level; + time = std::move(other.time); + thread_id = other.thread_id; + txt = std::move(other.txt); + msg_type = other.msg_type; + msg_id = other.msg_id; + return *this; + } + + // never copy or assign. should only be moved.. + async_msg(const async_msg&) = delete; + async_msg& operator=(const async_msg& other) = delete; + + // construct from log_msg + async_msg(const details::log_msg& m): + level(m.level), + time(m.time), + thread_id(m.thread_id), + txt(m.raw.data(), m.raw.size()), + msg_type(async_msg_type::log), + msg_id(m.msg_id) + { +#ifndef SPDLOG_NO_NAME + logger_name = *m.logger_name; +#endif + } + + + // copy into log_msg + void fill_log_msg(log_msg &msg) + { + msg.logger_name = &logger_name; + msg.level = level; + msg.time = time; + msg.thread_id = thread_id; + msg.raw << txt; + msg.msg_id = msg_id; + } + }; + +public: + + using item_type = async_msg; + using q_type = details::mpmc_bounded_queue; + + using clock = std::chrono::steady_clock; + + + async_log_helper(formatter_ptr formatter, + const std::vector& sinks, + size_t queue_size, + const log_err_handler err_handler, + const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, + const std::function& worker_warmup_cb = nullptr, + const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), + const std::function& worker_teardown_cb = nullptr); + + void log(const details::log_msg& msg); + + // stop logging and join the back thread + ~async_log_helper(); + + void set_formatter(formatter_ptr); + + void flush(bool wait_for_q); + + void set_error_handler(spdlog::log_err_handler err_handler); + +private: + formatter_ptr _formatter; + std::vector> _sinks; + + // queue of messages to log + q_type _q; + + log_err_handler _err_handler; + + bool _flush_requested; + + bool _terminate_requested; + + + // overflow policy + const async_overflow_policy _overflow_policy; + + // worker thread warmup callback - one can set thread priority, affinity, etc + const std::function _worker_warmup_cb; + + // auto periodic sink flush parameter + const std::chrono::milliseconds _flush_interval_ms; + + // worker thread teardown callback + const std::function _worker_teardown_cb; + + // worker thread + std::thread _worker_thread; + + void push_msg(async_msg&& new_msg); + + // worker thread main loop + void worker_loop(); + + // pop next message from the queue and process it. will set the last_pop to the pop time + // return false if termination of the queue is required + bool process_next_msg(log_clock::time_point& last_pop, log_clock::time_point& last_flush); + + void handle_flush_interval(log_clock::time_point& now, log_clock::time_point& last_flush); + + // sleep,yield or return immediately using the time passed since last message as a hint + static void sleep_or_yield(const spdlog::log_clock::time_point& now, const log_clock::time_point& last_op_time); + + // wait until the queue is empty + void wait_empty_q(); + +}; +} +} + +/////////////////////////////////////////////////////////////////////////////// +// async_sink class implementation +/////////////////////////////////////////////////////////////////////////////// +inline spdlog::details::async_log_helper::async_log_helper( + formatter_ptr formatter, + const std::vector& sinks, + size_t queue_size, + log_err_handler err_handler, + const async_overflow_policy overflow_policy, + const std::function& worker_warmup_cb, + const std::chrono::milliseconds& flush_interval_ms, + const std::function& worker_teardown_cb): + _formatter(formatter), + _sinks(sinks), + _q(queue_size), + _err_handler(err_handler), + _flush_requested(false), + _terminate_requested(false), + _overflow_policy(overflow_policy), + _worker_warmup_cb(worker_warmup_cb), + _flush_interval_ms(flush_interval_ms), + _worker_teardown_cb(worker_teardown_cb), + _worker_thread(&async_log_helper::worker_loop, this) +{} + +// Send to the worker thread termination message(level=off) +// and wait for it to finish gracefully +inline spdlog::details::async_log_helper::~async_log_helper() +{ + try + { + push_msg(async_msg(async_msg_type::terminate)); + _worker_thread.join(); + } + catch (...) // don't crash in destructor + { + } +} + + +//Try to push and block until succeeded (if the policy is not to discard when the queue is full) +inline void spdlog::details::async_log_helper::log(const details::log_msg& msg) +{ + push_msg(async_msg(msg)); +} + +inline void spdlog::details::async_log_helper::push_msg(details::async_log_helper::async_msg&& new_msg) +{ + if (!_q.enqueue(std::move(new_msg)) && _overflow_policy != async_overflow_policy::discard_log_msg) + { + auto last_op_time = details::os::now(); + auto now = last_op_time; + do + { + now = details::os::now(); + sleep_or_yield(now, last_op_time); + } + while (!_q.enqueue(std::move(new_msg))); + } +} + +// optionally wait for the queue be empty and request flush from the sinks +inline void spdlog::details::async_log_helper::flush(bool wait_for_q) +{ + push_msg(async_msg(async_msg_type::flush)); + if (wait_for_q) + wait_empty_q(); //return only make after the above flush message was processed +} + +inline void spdlog::details::async_log_helper::worker_loop() +{ + if (_worker_warmup_cb) _worker_warmup_cb(); + auto last_pop = details::os::now(); + auto last_flush = last_pop; + auto active = true; + while (active) + { + try + { + active = process_next_msg(last_pop, last_flush); + } + catch (const std::exception &ex) + { + _err_handler(ex.what()); + } + catch (...) + { + _err_handler("Unknown exception"); + } + } + if (_worker_teardown_cb) _worker_teardown_cb(); + + +} + +// process next message in the queue +// return true if this thread should still be active (while no terminate msg was received) +inline bool spdlog::details::async_log_helper::process_next_msg(log_clock::time_point& last_pop, log_clock::time_point& last_flush) +{ + async_msg incoming_async_msg; + + if (_q.dequeue(incoming_async_msg)) + { + last_pop = details::os::now(); + switch (incoming_async_msg.msg_type) + { + case async_msg_type::flush: + _flush_requested = true; + break; + + case async_msg_type::terminate: + _flush_requested = true; + _terminate_requested = true; + break; + + default: + log_msg incoming_log_msg; + incoming_async_msg.fill_log_msg(incoming_log_msg); + _formatter->format(incoming_log_msg); + for (auto &s : _sinks) + { + if (s->should_log(incoming_log_msg.level)) + { + s->log(incoming_log_msg); + } + } + } + return true; + } + + // Handle empty queue.. + // This is the only place where the queue can terminate or flush to avoid losing messages already in the queue + else + { + auto now = details::os::now(); + handle_flush_interval(now, last_flush); + sleep_or_yield(now, last_pop); + return !_terminate_requested; + } +} + +// flush all sinks if _flush_interval_ms has expired +inline void spdlog::details::async_log_helper::handle_flush_interval(log_clock::time_point& now, log_clock::time_point& last_flush) +{ + auto should_flush = _flush_requested || (_flush_interval_ms != std::chrono::milliseconds::zero() && now - last_flush >= _flush_interval_ms); + if (should_flush) + { + for (auto &s : _sinks) + s->flush(); + now = last_flush = details::os::now(); + _flush_requested = false; + } +} + +inline void spdlog::details::async_log_helper::set_formatter(formatter_ptr msg_formatter) +{ + _formatter = msg_formatter; +} + + +// spin, yield or sleep. use the time passed since last message as a hint +inline void spdlog::details::async_log_helper::sleep_or_yield(const spdlog::log_clock::time_point& now, const spdlog::log_clock::time_point& last_op_time) +{ + using namespace std::this_thread; + using std::chrono::milliseconds; + using std::chrono::microseconds; + + auto time_since_op = now - last_op_time; + + // spin upto 50 micros + if (time_since_op <= microseconds(50)) + return; + + // yield upto 150 micros + if (time_since_op <= microseconds(100)) + return std::this_thread::yield(); + + // sleep for 20 ms upto 200 ms + if (time_since_op <= milliseconds(200)) + return sleep_for(milliseconds(20)); + + // sleep for 200 ms + return sleep_for(milliseconds(200)); +} + +// wait for the queue to be empty +inline void spdlog::details::async_log_helper::wait_empty_q() +{ + auto last_op = details::os::now(); + while (_q.approx_size() > 0) + { + sleep_or_yield(details::os::now(), last_op); + } +} + +inline void spdlog::details::async_log_helper::set_error_handler(spdlog::log_err_handler err_handler) +{ + _err_handler = err_handler; +} + + + diff --git a/src/dionysus/wasserstein/spdlog/details/async_logger_impl.h b/src/dionysus/wasserstein/spdlog/details/async_logger_impl.h new file mode 100755 index 0000000..2cac488 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/async_logger_impl.h @@ -0,0 +1,105 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// Async Logger implementation +// Use an async_sink (queue per logger) to perform the logging in a worker thread + +#include "../details/async_log_helper.h" +#include "../async_logger.h" + +#include +#include +#include +#include + +template +inline spdlog::async_logger::async_logger(const std::string& logger_name, + const It& begin, + const It& end, + size_t queue_size, + const async_overflow_policy overflow_policy, + const std::function& worker_warmup_cb, + const std::chrono::milliseconds& flush_interval_ms, + const std::function& worker_teardown_cb) : + logger(logger_name, begin, end), + _async_log_helper(new details::async_log_helper(_formatter, _sinks, queue_size, _err_handler, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb)) +{ +} + +inline spdlog::async_logger::async_logger(const std::string& logger_name, + sinks_init_list sinks_list, + size_t queue_size, + const async_overflow_policy overflow_policy, + const std::function& worker_warmup_cb, + const std::chrono::milliseconds& flush_interval_ms, + const std::function& worker_teardown_cb) : + async_logger(logger_name, sinks_list.begin(), sinks_list.end(), queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb) {} + +inline spdlog::async_logger::async_logger(const std::string& logger_name, + sink_ptr single_sink, + size_t queue_size, + const async_overflow_policy overflow_policy, + const std::function& worker_warmup_cb, + const std::chrono::milliseconds& flush_interval_ms, + const std::function& worker_teardown_cb) : + async_logger(logger_name, +{ + single_sink +}, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb) {} + + +inline void spdlog::async_logger::flush() +{ + _async_log_helper->flush(true); +} + +// Error handler +inline void spdlog::async_logger::set_error_handler(spdlog::log_err_handler err_handler) +{ + _err_handler = err_handler; + _async_log_helper->set_error_handler(err_handler); + +} +inline spdlog::log_err_handler spdlog::async_logger::error_handler() +{ + return _err_handler; +} + + +inline void spdlog::async_logger::_set_formatter(spdlog::formatter_ptr msg_formatter) +{ + _formatter = msg_formatter; + _async_log_helper->set_formatter(_formatter); +} + +inline void spdlog::async_logger::_set_pattern(const std::string& pattern, pattern_time_type pattern_time) +{ + _formatter = std::make_shared(pattern, pattern_time); + _async_log_helper->set_formatter(_formatter); +} + + +inline void spdlog::async_logger::_sink_it(details::log_msg& msg) +{ + try + { +#if defined(SPDLOG_ENABLE_MESSAGE_COUNTER) + msg.msg_id = _msg_counter.fetch_add(1, std::memory_order_relaxed); +#endif + _async_log_helper->log(msg); + if (_should_flush_on(msg)) + _async_log_helper->flush(false); // do async flush + } + catch (const std::exception &ex) + { + _err_handler(ex.what()); + } + catch (...) + { + _err_handler("Unknown exception"); + } +} diff --git a/src/dionysus/wasserstein/spdlog/details/file_helper.h b/src/dionysus/wasserstein/spdlog/details/file_helper.h new file mode 100755 index 0000000..0d6b703 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/file_helper.h @@ -0,0 +1,117 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// Helper class for file sink +// When failing to open a file, retry several times(5) with small delay between the tries(10 ms) +// Throw spdlog_ex exception on errors + +#include "../details/os.h" +#include "../details/log_msg.h" + +#include +#include +#include +#include +#include + +namespace spdlog +{ +namespace details +{ + +class file_helper +{ + +public: + const int open_tries = 5; + const int open_interval = 10; + + explicit file_helper() : + _fd(nullptr) + {} + + file_helper(const file_helper&) = delete; + file_helper& operator=(const file_helper&) = delete; + + ~file_helper() + { + close(); + } + + + void open(const filename_t& fname, bool truncate = false) + { + + close(); + auto *mode = truncate ? SPDLOG_FILENAME_T("wb") : SPDLOG_FILENAME_T("ab"); + _filename = fname; + for (int tries = 0; tries < open_tries; ++tries) + { + if (!os::fopen_s(&_fd, fname, mode)) + return; + + std::this_thread::sleep_for(std::chrono::milliseconds(open_interval)); + } + + throw spdlog_ex("Failed opening file " + os::filename_to_str(_filename) + " for writing", errno); + } + + void reopen(bool truncate) + { + if (_filename.empty()) + throw spdlog_ex("Failed re opening file - was not opened before"); + open(_filename, truncate); + + } + + void flush() + { + std::fflush(_fd); + } + + void close() + { + if (_fd) + { + std::fclose(_fd); + _fd = nullptr; + } + } + + void write(const log_msg& msg) + { + + size_t msg_size = msg.formatted.size(); + auto data = msg.formatted.data(); + if (std::fwrite(data, 1, msg_size, _fd) != msg_size) + throw spdlog_ex("Failed writing to file " + os::filename_to_str(_filename), errno); + } + + size_t size() + { + if (!_fd) + throw spdlog_ex("Cannot use size() on closed file " + os::filename_to_str(_filename)); + return os::filesize(_fd); + } + + const filename_t& filename() const + { + return _filename; + } + + static bool file_exists(const filename_t& name) + { + + return os::file_exists(name); + } + +private: + FILE* _fd; + filename_t _filename; +}; +} +} diff --git a/src/dionysus/wasserstein/spdlog/details/log_msg.h b/src/dionysus/wasserstein/spdlog/details/log_msg.h new file mode 100755 index 0000000..a9fe920 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/log_msg.h @@ -0,0 +1,50 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../common.h" +#include "../details/os.h" + + +#include +#include + +namespace spdlog +{ +namespace details +{ +struct log_msg +{ + log_msg() = default; + log_msg(const std::string *loggers_name, level::level_enum lvl) : + logger_name(loggers_name), + level(lvl), + msg_id(0) + { +#ifndef SPDLOG_NO_DATETIME + time = os::now(); +#endif + +#ifndef SPDLOG_NO_THREAD_ID + thread_id = os::thread_id(); +#endif + } + + log_msg(const log_msg& other) = delete; + log_msg& operator=(log_msg&& other) = delete; + log_msg(log_msg&& other) = delete; + + + const std::string *logger_name; + level::level_enum level; + log_clock::time_point time; + size_t thread_id; + fmt::MemoryWriter raw; + fmt::MemoryWriter formatted; + size_t msg_id; +}; +} +} diff --git a/src/dionysus/wasserstein/spdlog/details/logger_impl.h b/src/dionysus/wasserstein/spdlog/details/logger_impl.h new file mode 100755 index 0000000..9203769 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/logger_impl.h @@ -0,0 +1,563 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../logger.h" +#include "../sinks/stdout_sinks.h" + +#include +#include + + +// create logger with given name, sinks and the default pattern formatter +// all other ctors will call this one +template +inline spdlog::logger::logger(const std::string& logger_name, const It& begin, const It& end): + _name(logger_name), + _sinks(begin, end), + _formatter(std::make_shared("%+")), + _level(level::info), + _flush_level(level::off), + _last_err_time(0), + _msg_counter(1) // message counter will start from 1. 0-message id will be reserved for controll messages +{ + _err_handler = [this](const std::string &msg) + { + this->_default_err_handler(msg); + }; +} + +// ctor with sinks as init list +inline spdlog::logger::logger(const std::string& logger_name, sinks_init_list sinks_list): + logger(logger_name, sinks_list.begin(), sinks_list.end()) +{} + + +// ctor with single sink +inline spdlog::logger::logger(const std::string& logger_name, spdlog::sink_ptr single_sink): + logger(logger_name, +{ + single_sink +}) +{} + + +inline spdlog::logger::~logger() = default; + + +inline void spdlog::logger::set_formatter(spdlog::formatter_ptr msg_formatter) +{ + _set_formatter(msg_formatter); +} + +inline void spdlog::logger::set_pattern(const std::string& pattern, pattern_time_type pattern_time) +{ + _set_pattern(pattern, pattern_time); +} + + +template +inline void spdlog::logger::log(level::level_enum lvl, const char* fmt, const Args&... args) +{ + if (!should_log(lvl)) return; + + try + { + details::log_msg log_msg(&_name, lvl); + log_msg.raw.write(fmt, args...); + _sink_it(log_msg); + } + catch (const std::exception &ex) + { + _err_handler(ex.what()); + } + catch (...) + { + _err_handler("Unknown exception"); + } +} + +template +inline void spdlog::logger::log(level::level_enum lvl, const char* msg) +{ + if (!should_log(lvl)) return; + try + { + details::log_msg log_msg(&_name, lvl); + log_msg.raw << msg; + _sink_it(log_msg); + } + catch (const std::exception &ex) + { + _err_handler(ex.what()); + } + catch (...) + { + _err_handler("Unknown exception"); + } + +} + +template +inline void spdlog::logger::log(level::level_enum lvl, const T& msg) +{ + if (!should_log(lvl)) return; + try + { + details::log_msg log_msg(&_name, lvl); + log_msg.raw << msg; + _sink_it(log_msg); + } + catch (const std::exception &ex) + { + _err_handler(ex.what()); + } + catch (...) + { + _err_handler("Unknown exception"); + } +} + + +template +inline void spdlog::logger::trace(const char* fmt, const Arg1 &arg1, const Args&... args) +{ + log(level::trace, fmt, arg1, args...); +} + +template +inline void spdlog::logger::debug(const char* fmt, const Arg1 &arg1, const Args&... args) +{ + log(level::debug, fmt, arg1, args...); +} + +template +inline void spdlog::logger::info(const char* fmt, const Arg1 &arg1, const Args&... args) +{ + log(level::info, fmt, arg1, args...); +} + +template +inline void spdlog::logger::warn(const char* fmt, const Arg1 &arg1, const Args&... args) +{ + log(level::warn, fmt, arg1, args...); +} + +template +inline void spdlog::logger::error(const char* fmt, const Arg1 &arg1, const Args&... args) +{ + log(level::err, fmt, arg1, args...); +} + +template +inline void spdlog::logger::critical(const char* fmt, const Arg1 &arg1, const Args&... args) +{ + log(level::critical, fmt, arg1, args...); +} + +template +inline void spdlog::logger::log_if(const bool flag, level::level_enum lvl, const char* msg) +{ + if (flag) + { + log(lvl, msg); + } +} + +template +inline void spdlog::logger::log_if(const bool flag, level::level_enum lvl, const T& msg) +{ + if (flag) + { + log(lvl, msg); + } +} + +template +inline void spdlog::logger::trace_if(const bool flag, const char* fmt, const Arg1 &arg1, const Args&... args) +{ + if (flag) + { + log(level::trace, fmt, arg1, args...); + } +} + +template +inline void spdlog::logger::debug_if(const bool flag, const char* fmt, const Arg1 &arg1, const Args&... args) +{ + if (flag) + { + log(level::debug, fmt, arg1, args...); + } +} + +template +inline void spdlog::logger::info_if(const bool flag, const char* fmt, const Arg1 &arg1, const Args&... args) +{ + if (flag) + { + log(level::info, fmt, arg1, args...); + } +} + +template +inline void spdlog::logger::warn_if(const bool flag, const char* fmt, const Arg1& arg1, const Args&... args) +{ + if (flag) + { + log(level::warn, fmt, arg1, args...); + } +} + +template +inline void spdlog::logger::error_if(const bool flag, const char* fmt, const Arg1 &arg1, const Args&... args) +{ + if (flag) + { + log(level::err, fmt, arg1, args...); + } +} + +template +inline void spdlog::logger::critical_if(const bool flag, const char* fmt, const Arg1 &arg1, const Args&... args) +{ + if (flag) + { + log(level::critical, fmt, arg1, args...); + } +} + + +template +inline void spdlog::logger::trace(const T& msg) +{ + log(level::trace, msg); +} + +template +inline void spdlog::logger::debug(const T& msg) +{ + log(level::debug, msg); +} + + +template +inline void spdlog::logger::info(const T& msg) +{ + log(level::info, msg); +} + + +template +inline void spdlog::logger::warn(const T& msg) +{ + log(level::warn, msg); +} + +template +inline void spdlog::logger::error(const T& msg) +{ + log(level::err, msg); +} + +template +inline void spdlog::logger::critical(const T& msg) +{ + log(level::critical, msg); +} + +template +inline void spdlog::logger::trace_if(const bool flag, const T& msg) +{ + if (flag) + { + log(level::trace, msg); + } +} + +template +inline void spdlog::logger::debug_if(const bool flag, const T& msg) +{ + if (flag) + { + log(level::debug, msg); + } +} + +template +inline void spdlog::logger::info_if(const bool flag, const T& msg) +{ + if (flag) + { + log(level::info, msg); + } +} + +template +inline void spdlog::logger::warn_if(const bool flag, const T& msg) +{ + if (flag) + { + log(level::warn, msg); + } +} + +template +inline void spdlog::logger::error_if(const bool flag, const T& msg) +{ + if (flag) + { + log(level::err, msg); + } +} + +template +inline void spdlog::logger::critical_if(const bool flag, const T& msg) +{ + if (flag) + { + log(level::critical, msg); + } +} + + +#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT +#include + +template +inline void spdlog::logger::log(level::level_enum lvl, const wchar_t* msg) +{ + std::wstring_convert > conv; + + log(lvl, conv.to_bytes(msg)); +} + +template +inline void spdlog::logger::log(level::level_enum lvl, const wchar_t* fmt, const Args&... args) +{ + fmt::WMemoryWriter wWriter; + + wWriter.write(fmt, args...); + log(lvl, wWriter.c_str()); +} + +template +inline void spdlog::logger::trace(const wchar_t* fmt, const Args&... args) +{ + log(level::trace, fmt, args...); +} + +template +inline void spdlog::logger::debug(const wchar_t* fmt, const Args&... args) +{ + log(level::debug, fmt, args...); +} + +template +inline void spdlog::logger::info(const wchar_t* fmt, const Args&... args) +{ + log(level::info, fmt, args...); +} + + +template +inline void spdlog::logger::warn(const wchar_t* fmt, const Args&... args) +{ + log(level::warn, fmt, args...); +} + +template +inline void spdlog::logger::error(const wchar_t* fmt, const Args&... args) +{ + log(level::err, fmt, args...); +} + +template +inline void spdlog::logger::critical(const wchar_t* fmt, const Args&... args) +{ + log(level::critical, fmt, args...); +} + +// +// conditional logging +// + +template +inline void spdlog::logger::log_if(const bool flag, level::level_enum lvl, const wchar_t* msg) +{ + if (flag) + { + log(lvl, msg); + } +} + +template +inline void spdlog::logger::log_if(const bool flag, level::level_enum lvl, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(lvl, fmt, args); + } +} + +template +inline void spdlog::logger::trace_if(const bool flag, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(level::trace, fmt, args...); + } +} + +template +inline void spdlog::logger::debug_if(const bool flag, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(level::debug, fmt, args...); + } +} + +template +inline void spdlog::logger::info_if(const bool flag, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(level::info, fmt, args...); + } +} + + +template +inline void spdlog::logger::warn_if(const bool flag, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(level::warn, fmt, args...); + } +} + +template +inline void spdlog::logger::error_if(const bool flag, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(level::err, fmt, args...); + } +} + +template +inline void spdlog::logger::critical_if(const bool flag, const wchar_t* fmt, const Args&... args) +{ + if (flag) + { + log(level::critical, fmt, args...); + } +} + +#endif // SPDLOG_WCHAR_TO_UTF8_SUPPORT + + + +// +// name and level +// +inline const std::string& spdlog::logger::name() const +{ + return _name; +} + +inline void spdlog::logger::set_level(spdlog::level::level_enum log_level) +{ + _level.store(log_level); +} + +inline void spdlog::logger::set_error_handler(spdlog::log_err_handler err_handler) +{ + _err_handler = err_handler; +} + +inline spdlog::log_err_handler spdlog::logger::error_handler() +{ + return _err_handler; +} + + +inline void spdlog::logger::flush_on(level::level_enum log_level) +{ + _flush_level.store(log_level); +} + +inline spdlog::level::level_enum spdlog::logger::level() const +{ + return static_cast(_level.load(std::memory_order_relaxed)); +} + +inline bool spdlog::logger::should_log(spdlog::level::level_enum msg_level) const +{ + return msg_level >= _level.load(std::memory_order_relaxed); +} + +// +// protected virtual called at end of each user log call (if enabled) by the line_logger +// +inline void spdlog::logger::_sink_it(details::log_msg& msg) +{ +#if defined(SPDLOG_ENABLE_MESSAGE_COUNTER) + msg.msg_id = _msg_counter.fetch_add(1, std::memory_order_relaxed); +#endif + _formatter->format(msg); + for (auto &sink : _sinks) + { + if( sink->should_log( msg.level)) + { + sink->log(msg); + } + } + + if(_should_flush_on(msg)) + flush(); +} + +inline void spdlog::logger::_set_pattern(const std::string& pattern, pattern_time_type pattern_time) +{ + _formatter = std::make_shared(pattern, pattern_time); +} +inline void spdlog::logger::_set_formatter(formatter_ptr msg_formatter) +{ + _formatter = msg_formatter; +} + +inline void spdlog::logger::flush() +{ + for (auto& sink : _sinks) + sink->flush(); +} + +inline void spdlog::logger::_default_err_handler(const std::string &msg) +{ + auto now = time(nullptr); + if (now - _last_err_time < 60) + return; + auto tm_time = details::os::localtime(now); + char date_buf[100]; + std::strftime(date_buf, sizeof(date_buf), "%Y-%m-%d %H:%M:%S", &tm_time); + details::log_msg err_msg; + err_msg.formatted.write("[*** LOG ERROR ***] [{}] [{}] [{}]{}", name(), msg, date_buf, details::os::eol); + sinks::stderr_sink_mt::instance()->log(err_msg); + _last_err_time = now; +} + +inline bool spdlog::logger::_should_flush_on(const details::log_msg &msg) +{ + const auto flush_level = _flush_level.load(std::memory_order_relaxed); + return (msg.level >= flush_level) && (msg.level != level::off); +} + +inline const std::vector& spdlog::logger::sinks() const +{ + return _sinks; +} diff --git a/src/dionysus/wasserstein/spdlog/details/mpmc_bounded_q.h b/src/dionysus/wasserstein/spdlog/details/mpmc_bounded_q.h new file mode 100755 index 0000000..378f5d2 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/mpmc_bounded_q.h @@ -0,0 +1,172 @@ +/* +A modified version of Bounded MPMC queue by Dmitry Vyukov. + +Original code from: +http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue + +licensed by Dmitry Vyukov under the terms below: + +Simplified BSD license + +Copyright (c) 2010-2011 Dmitry Vyukov. All rights reserved. +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of +conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list +of conditions and the following disclaimer in the documentation and/or other materials +provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY DMITRY VYUKOV "AS IS" AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT +SHALL DMITRY VYUKOV OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +The views and conclusions contained in the software and documentation are those of the authors and +should not be interpreted as representing official policies, either expressed or implied, of Dmitry Vyukov. +*/ + +/* +The code in its current form adds the license below: + +Copyright(c) 2015 Gabi Melman. +Distributed under the MIT License (http://opensource.org/licenses/MIT) + +*/ + +#pragma once + +#include "../common.h" + +#include +#include + +namespace spdlog +{ +namespace details +{ + +template +class mpmc_bounded_queue +{ +public: + + using item_type = T; + mpmc_bounded_queue(size_t buffer_size) + :max_size_(buffer_size), + buffer_(new cell_t [buffer_size]), + buffer_mask_(buffer_size - 1) + { + //queue size must be power of two + if(!((buffer_size >= 2) && ((buffer_size & (buffer_size - 1)) == 0))) + throw spdlog_ex("async logger queue size must be power of two"); + + for (size_t i = 0; i != buffer_size; i += 1) + buffer_[i].sequence_.store(i, std::memory_order_relaxed); + enqueue_pos_.store(0, std::memory_order_relaxed); + dequeue_pos_.store(0, std::memory_order_relaxed); + } + + ~mpmc_bounded_queue() + { + delete [] buffer_; + } + + + bool enqueue(T&& data) + { + cell_t* cell; + size_t pos = enqueue_pos_.load(std::memory_order_relaxed); + for (;;) + { + cell = &buffer_[pos & buffer_mask_]; + size_t seq = cell->sequence_.load(std::memory_order_acquire); + intptr_t dif = (intptr_t)seq - (intptr_t)pos; + if (dif == 0) + { + if (enqueue_pos_.compare_exchange_weak(pos, pos + 1, std::memory_order_relaxed)) + break; + } + else if (dif < 0) + { + return false; + } + else + { + pos = enqueue_pos_.load(std::memory_order_relaxed); + } + } + cell->data_ = std::move(data); + cell->sequence_.store(pos + 1, std::memory_order_release); + return true; + } + + bool dequeue(T& data) + { + cell_t* cell; + size_t pos = dequeue_pos_.load(std::memory_order_relaxed); + for (;;) + { + cell = &buffer_[pos & buffer_mask_]; + size_t seq = + cell->sequence_.load(std::memory_order_acquire); + intptr_t dif = (intptr_t)seq - (intptr_t)(pos + 1); + if (dif == 0) + { + if (dequeue_pos_.compare_exchange_weak(pos, pos + 1, std::memory_order_relaxed)) + break; + } + else if (dif < 0) + return false; + else + pos = dequeue_pos_.load(std::memory_order_relaxed); + } + data = std::move(cell->data_); + cell->sequence_.store(pos + buffer_mask_ + 1, std::memory_order_release); + return true; + } + + size_t approx_size() + { + size_t first_pos = dequeue_pos_.load(std::memory_order_relaxed); + size_t last_pos = enqueue_pos_.load(std::memory_order_relaxed); + if (last_pos <= first_pos) + return 0; + auto size = last_pos - first_pos; + return size < max_size_ ? size : max_size_; + } + +private: + struct cell_t + { + std::atomic sequence_; + T data_; + }; + + size_t const max_size_; + + static size_t const cacheline_size = 64; + typedef char cacheline_pad_t [cacheline_size]; + + cacheline_pad_t pad0_; + cell_t* const buffer_; + size_t const buffer_mask_; + cacheline_pad_t pad1_; + std::atomic enqueue_pos_; + cacheline_pad_t pad2_; + std::atomic dequeue_pos_; + cacheline_pad_t pad3_; + + mpmc_bounded_queue(mpmc_bounded_queue const&) = delete; + void operator= (mpmc_bounded_queue const&) = delete; +}; + +} // ns details +} // ns spdlog diff --git a/src/dionysus/wasserstein/spdlog/details/null_mutex.h b/src/dionysus/wasserstein/spdlog/details/null_mutex.h new file mode 100755 index 0000000..67b0aee --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/null_mutex.h @@ -0,0 +1,45 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include +// null, no cost dummy "mutex" and dummy "atomic" int + +namespace spdlog +{ +namespace details +{ +struct null_mutex +{ + void lock() {} + void unlock() {} + bool try_lock() + { + return true; + } +}; + +struct null_atomic_int +{ + int value; + null_atomic_int() = default; + + null_atomic_int(int val):value(val) + {} + + int load(std::memory_order) const + { + return value; + } + + void store(int val) + { + value = val; + } +}; + +} +} diff --git a/src/dionysus/wasserstein/spdlog/details/os.h b/src/dionysus/wasserstein/spdlog/details/os.h new file mode 100755 index 0000000..3503680 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/os.h @@ -0,0 +1,469 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// +#pragma once + +#include "../common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + +#ifndef NOMINMAX +#define NOMINMAX //prevent windows redefining min/max +#endif + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#include // _get_pid support +#include // _get_osfhandle and _isatty support + +#ifdef __MINGW32__ +#include +#endif + +#else // unix + +#include +#include + +#ifdef __linux__ +#include //Use gettid() syscall under linux to get thread id + +#elif __FreeBSD__ +#include //Use thr_self() syscall under FreeBSD to get thread id +#endif + +#endif //unix + +#ifndef __has_feature // Clang - feature checking macros. +#define __has_feature(x) 0 // Compatibility with non-clang compilers. +#endif + + +namespace spdlog +{ +namespace details +{ +namespace os +{ + +inline spdlog::log_clock::time_point now() +{ + +#if defined __linux__ && defined SPDLOG_CLOCK_COARSE + timespec ts; + ::clock_gettime(CLOCK_REALTIME_COARSE, &ts); + return std::chrono::time_point( + std::chrono::duration_cast( + std::chrono::seconds(ts.tv_sec) + std::chrono::nanoseconds(ts.tv_nsec))); + + +#else + return log_clock::now(); +#endif + +} +inline std::tm localtime(const std::time_t &time_tt) +{ + +#ifdef _WIN32 + std::tm tm; + localtime_s(&tm, &time_tt); +#else + std::tm tm; + localtime_r(&time_tt, &tm); +#endif + return tm; +} + +inline std::tm localtime() +{ + std::time_t now_t = time(nullptr); + return localtime(now_t); +} + + +inline std::tm gmtime(const std::time_t &time_tt) +{ + +#ifdef _WIN32 + std::tm tm; + gmtime_s(&tm, &time_tt); +#else + std::tm tm; + gmtime_r(&time_tt, &tm); +#endif + return tm; +} + +inline std::tm gmtime() +{ + std::time_t now_t = time(nullptr); + return gmtime(now_t); +} +inline bool operator==(const std::tm& tm1, const std::tm& tm2) +{ + return (tm1.tm_sec == tm2.tm_sec && + tm1.tm_min == tm2.tm_min && + tm1.tm_hour == tm2.tm_hour && + tm1.tm_mday == tm2.tm_mday && + tm1.tm_mon == tm2.tm_mon && + tm1.tm_year == tm2.tm_year && + tm1.tm_isdst == tm2.tm_isdst); +} + +inline bool operator!=(const std::tm& tm1, const std::tm& tm2) +{ + return !(tm1 == tm2); +} + +// eol definition +#if !defined (SPDLOG_EOL) +#ifdef _WIN32 +#define SPDLOG_EOL "\r\n" +#else +#define SPDLOG_EOL "\n" +#endif +#endif + +SPDLOG_CONSTEXPR static const char* eol = SPDLOG_EOL; +SPDLOG_CONSTEXPR static int eol_size = sizeof(SPDLOG_EOL) - 1; + +inline void prevent_child_fd(FILE *f) +{ +#ifdef _WIN32 + auto file_handle = (HANDLE)_get_osfhandle(_fileno(f)); + if (!::SetHandleInformation(file_handle, HANDLE_FLAG_INHERIT, 0)) + throw spdlog_ex("SetHandleInformation failed", errno); +#else + auto fd = fileno(f); + if (fcntl(fd, F_SETFD, FD_CLOEXEC) == -1) + throw spdlog_ex("fcntl with FD_CLOEXEC failed", errno); +#endif +} + + +//fopen_s on non windows for writing +inline int fopen_s(FILE** fp, const filename_t& filename, const filename_t& mode) +{ +#ifdef _WIN32 +#ifdef SPDLOG_WCHAR_FILENAMES + *fp = _wfsopen((filename.c_str()), mode.c_str(), _SH_DENYWR); +#else + *fp = _fsopen((filename.c_str()), mode.c_str(), _SH_DENYWR); +#endif +#else //unix + *fp = fopen((filename.c_str()), mode.c_str()); +#endif + +#ifdef SPDLOG_PREVENT_CHILD_FD + if (*fp != nullptr) + prevent_child_fd(*fp); +#endif + return *fp == nullptr; +} + + +inline int remove(const filename_t &filename) +{ +#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) + return _wremove(filename.c_str()); +#else + return std::remove(filename.c_str()); +#endif +} + +inline int rename(const filename_t& filename1, const filename_t& filename2) +{ +#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) + return _wrename(filename1.c_str(), filename2.c_str()); +#else + return std::rename(filename1.c_str(), filename2.c_str()); +#endif +} + + +//Return if file exists +inline bool file_exists(const filename_t& filename) +{ +#ifdef _WIN32 +#ifdef SPDLOG_WCHAR_FILENAMES + auto attribs = GetFileAttributesW(filename.c_str()); +#else + auto attribs = GetFileAttributesA(filename.c_str()); +#endif + return (attribs != INVALID_FILE_ATTRIBUTES && !(attribs & FILE_ATTRIBUTE_DIRECTORY)); +#else //common linux/unix all have the stat system call + struct stat buffer; + return (stat(filename.c_str(), &buffer) == 0); +#endif +} + + + + +//Return file size according to open FILE* object +inline size_t filesize(FILE *f) +{ + if (f == nullptr) + throw spdlog_ex("Failed getting file size. fd is null"); +#ifdef _WIN32 + int fd = _fileno(f); +#if _WIN64 //64 bits + struct _stat64 st; + if (_fstat64(fd, &st) == 0) + return st.st_size; + +#else //windows 32 bits + long ret = _filelength(fd); + if (ret >= 0) + return static_cast(ret); +#endif + +#else // unix + int fd = fileno(f); + //64 bits(but not in osx, where fstat64 is deprecated) +#if !defined(__FreeBSD__) && !defined(__APPLE__) && (defined(__x86_64__) || defined(__ppc64__)) + struct stat64 st; + if (fstat64(fd, &st) == 0) + return static_cast(st.st_size); +#else // unix 32 bits or osx + struct stat st; + if (fstat(fd, &st) == 0) + return static_cast(st.st_size); +#endif +#endif + throw spdlog_ex("Failed getting file size from fd", errno); +} + + + + +//Return utc offset in minutes or throw spdlog_ex on failure +inline int utc_minutes_offset(const std::tm& tm = details::os::localtime()) +{ + +#ifdef _WIN32 +#if _WIN32_WINNT < _WIN32_WINNT_WS08 + TIME_ZONE_INFORMATION tzinfo; + auto rv = GetTimeZoneInformation(&tzinfo); +#else + DYNAMIC_TIME_ZONE_INFORMATION tzinfo; + auto rv = GetDynamicTimeZoneInformation(&tzinfo); +#endif + if (rv == TIME_ZONE_ID_INVALID) + throw spdlog::spdlog_ex("Failed getting timezone info. ", errno); + + int offset = -tzinfo.Bias; + if (tm.tm_isdst) + offset -= tzinfo.DaylightBias; + else + offset -= tzinfo.StandardBias; + return offset; +#else + +#if defined(sun) || defined(__sun) + // 'tm_gmtoff' field is BSD extension and it's missing on SunOS/Solaris + struct helper + { + static long int calculate_gmt_offset(const std::tm & localtm = details::os::localtime(), const std::tm & gmtm = details::os::gmtime()) + { + int local_year = localtm.tm_year + (1900 - 1); + int gmt_year = gmtm.tm_year + (1900 - 1); + + long int days = ( + // difference in day of year + localtm.tm_yday - gmtm.tm_yday + + // + intervening leap days + + ((local_year >> 2) - (gmt_year >> 2)) + - (local_year / 100 - gmt_year / 100) + + ((local_year / 100 >> 2) - (gmt_year / 100 >> 2)) + + // + difference in years * 365 */ + + (long int)(local_year - gmt_year) * 365 + ); + + long int hours = (24 * days) + (localtm.tm_hour - gmtm.tm_hour); + long int mins = (60 * hours) + (localtm.tm_min - gmtm.tm_min); + long int secs = (60 * mins) + (localtm.tm_sec - gmtm.tm_sec); + + return secs; + } + }; + + long int offset_seconds = helper::calculate_gmt_offset(tm); +#else + long int offset_seconds = tm.tm_gmtoff; +#endif + + return static_cast(offset_seconds / 60); +#endif +} + +//Return current thread id as size_t +//It exists because the std::this_thread::get_id() is much slower(espcially under VS 2013) +inline size_t _thread_id() +{ +#ifdef _WIN32 + return static_cast(::GetCurrentThreadId()); +#elif __linux__ +# if defined(__ANDROID__) && defined(__ANDROID_API__) && (__ANDROID_API__ < 21) +# define SYS_gettid __NR_gettid +# endif + return static_cast(syscall(SYS_gettid)); +#elif __FreeBSD__ + long tid; + thr_self(&tid); + return static_cast(tid); +#elif __APPLE__ + uint64_t tid; + pthread_threadid_np(nullptr, &tid); + return static_cast(tid); +#else //Default to standard C++11 (other Unix) + return static_cast(std::hash()(std::this_thread::get_id())); +#endif +} + +//Return current thread id as size_t (from thread local storage) +inline size_t thread_id() +{ +#if defined(_MSC_VER) && (_MSC_VER < 1900) || defined(__clang__) && !__has_feature(cxx_thread_local) + return _thread_id(); +#else + static thread_local const size_t tid = _thread_id(); + return tid; +#endif +} + + + + +// wchar support for windows file names (SPDLOG_WCHAR_FILENAMES must be defined) +#if defined(_WIN32) && defined(SPDLOG_WCHAR_FILENAMES) +#define SPDLOG_FILENAME_T(s) L ## s +inline std::string filename_to_str(const filename_t& filename) +{ + std::wstring_convert, wchar_t> c; + return c.to_bytes(filename); +} +#else +#define SPDLOG_FILENAME_T(s) s +inline std::string filename_to_str(const filename_t& filename) +{ + return filename; +} +#endif + +inline std::string errno_to_string(char[256], char* res) +{ + return std::string(res); +} + +inline std::string errno_to_string(char buf[256], int res) +{ + if (res == 0) + { + return std::string(buf); + } + else + { + return "Unknown error"; + } +} + +// Return errno string (thread safe) +inline std::string errno_str(int err_num) +{ + char buf[256]; + SPDLOG_CONSTEXPR auto buf_size = sizeof(buf); + +#ifdef _WIN32 + if (strerror_s(buf, buf_size, err_num) == 0) + return std::string(buf); + else + return "Unknown error"; + +#elif defined(__FreeBSD__) || defined(__APPLE__) || defined(ANDROID) || defined(__SUNPRO_CC) || \ + ((_POSIX_C_SOURCE >= 200112L) && ! defined(_GNU_SOURCE)) // posix version + + if (strerror_r(err_num, buf, buf_size) == 0) + return std::string(buf); + else + return "Unknown error"; + +#else // gnu version (might not use the given buf, so its retval pointer must be used) + auto err = strerror_r(err_num, buf, buf_size); // let compiler choose type + return errno_to_string(buf, err); // use overloading to select correct stringify function +#endif +} + +inline int pid() +{ + +#ifdef _WIN32 + return ::_getpid(); +#else + return static_cast(::getpid()); +#endif + +} + + +// Detrmine if the terminal supports colors +// Source: https://github.com/agauniyal/rang/ +inline bool is_color_terminal() +{ +#ifdef _WIN32 + return true; +#else + static constexpr const char* Terms[] = + { + "ansi", "color", "console", "cygwin", "gnome", "konsole", "kterm", + "linux", "msys", "putty", "rxvt", "screen", "vt100", "xterm" + }; + + const char *env_p = std::getenv("TERM"); + if (env_p == nullptr) + { + return false; + } + + static const bool result = std::any_of( + std::begin(Terms), std::end(Terms), [&](const char* term) + { + return std::strstr(env_p, term) != nullptr; + }); + return result; +#endif +} + + +// Detrmine if the terminal attached +// Source: https://github.com/agauniyal/rang/ +inline bool in_terminal(FILE* file) +{ + +#ifdef _WIN32 + return _isatty(_fileno(file)) ? true : false; +#else + return isatty(fileno(file)) ? true : false; +#endif +} +} //os +} //details +} //spdlog diff --git a/src/dionysus/wasserstein/spdlog/details/pattern_formatter_impl.h b/src/dionysus/wasserstein/spdlog/details/pattern_formatter_impl.h new file mode 100755 index 0000000..058c34d --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/pattern_formatter_impl.h @@ -0,0 +1,690 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../formatter.h" +#include "../details/log_msg.h" +#include "../details/os.h" +#include "../fmt/fmt.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spdlog +{ +namespace details +{ +class flag_formatter +{ +public: + virtual ~flag_formatter() + {} + virtual void format(details::log_msg& msg, const std::tm& tm_time) = 0; +}; + +/////////////////////////////////////////////////////////////////////// +// name & level pattern appenders +/////////////////////////////////////////////////////////////////////// +namespace +{ +class name_formatter:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << *msg.logger_name; + } +}; +} + +// log level appender +class level_formatter:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << level::to_str(msg.level); + } +}; + +// short log level appender +class short_level_formatter:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << level::to_short_str(msg.level); + } +}; + +/////////////////////////////////////////////////////////////////////// +// Date time pattern appenders +/////////////////////////////////////////////////////////////////////// + +static const char* ampm(const tm& t) +{ + return t.tm_hour >= 12 ? "PM" : "AM"; +} + +static int to12h(const tm& t) +{ + return t.tm_hour > 12 ? t.tm_hour - 12 : t.tm_hour; +} + +//Abbreviated weekday name +using days_array = std::array; +static const days_array& days() +{ + static const days_array arr{ { "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat" } }; + return arr; +} +class a_formatter:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << days()[tm_time.tm_wday]; + } +}; +// message counter formatter +class i_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << '#' << msg.msg_id; + } +}; +//Full weekday name +static const days_array& full_days() +{ + static const days_array arr{ { "Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday" } }; + return arr; +} +class A_formatter SPDLOG_FINAL :public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << full_days()[tm_time.tm_wday]; + } +}; + +//Abbreviated month +using months_array = std::array; +static const months_array& months() +{ + static const months_array arr{ { "Jan", "Feb", "Mar", "Apr", "May", "June", "July", "Aug", "Sept", "Oct", "Nov", "Dec" } }; + return arr; +} +class b_formatter:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << months()[tm_time.tm_mon]; + } +}; + +//Full month name +static const months_array& full_months() +{ + static const months_array arr{ { "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December" } }; + return arr; +} +class B_formatter SPDLOG_FINAL :public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << full_months()[tm_time.tm_mon]; + } +}; + + +//write 2 ints seperated by sep with padding of 2 +static fmt::MemoryWriter& pad_n_join(fmt::MemoryWriter& w, int v1, int v2, char sep) +{ + w << fmt::pad(v1, 2, '0') << sep << fmt::pad(v2, 2, '0'); + return w; +} + +//write 3 ints seperated by sep with padding of 2 +static fmt::MemoryWriter& pad_n_join(fmt::MemoryWriter& w, int v1, int v2, int v3, char sep) +{ + w << fmt::pad(v1, 2, '0') << sep << fmt::pad(v2, 2, '0') << sep << fmt::pad(v3, 2, '0'); + return w; +} + + +//Date and time representation (Thu Aug 23 15:35:46 2014) +class c_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << days()[tm_time.tm_wday] << ' ' << months()[tm_time.tm_mon] << ' ' << tm_time.tm_mday << ' '; + pad_n_join(msg.formatted, tm_time.tm_hour, tm_time.tm_min, tm_time.tm_sec, ':') << ' ' << tm_time.tm_year + 1900; + } +}; + + +// year - 2 digit +class C_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(tm_time.tm_year % 100, 2, '0'); + } +}; + + + +// Short MM/DD/YY date, equivalent to %m/%d/%y 08/23/01 +class D_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + pad_n_join(msg.formatted, tm_time.tm_mon + 1, tm_time.tm_mday, tm_time.tm_year % 100, '/'); + } +}; + + +// year - 4 digit +class Y_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << tm_time.tm_year + 1900; + } +}; + +// month 1-12 +class m_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(tm_time.tm_mon + 1, 2, '0'); + } +}; + +// day of month 1-31 +class d_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(tm_time.tm_mday, 2, '0'); + } +}; + +// hours in 24 format 0-23 +class H_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(tm_time.tm_hour, 2, '0'); + } +}; + +// hours in 12 format 1-12 +class I_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(to12h(tm_time), 2, '0'); + } +}; + +// minutes 0-59 +class M_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(tm_time.tm_min, 2, '0'); + } +}; + +// seconds 0-59 +class S_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << fmt::pad(tm_time.tm_sec, 2, '0'); + } +}; + +// milliseconds +class e_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + auto duration = msg.time.time_since_epoch(); + auto millis = std::chrono::duration_cast(duration).count() % 1000; + msg.formatted << fmt::pad(static_cast(millis), 3, '0'); + } +}; + +// microseconds +class f_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + auto duration = msg.time.time_since_epoch(); + auto micros = std::chrono::duration_cast(duration).count() % 1000000; + msg.formatted << fmt::pad(static_cast(micros), 6, '0'); + } +}; + +// nanoseconds +class F_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + auto duration = msg.time.time_since_epoch(); + auto ns = std::chrono::duration_cast(duration).count() % 1000000000; + msg.formatted << fmt::pad(static_cast(ns), 9, '0'); + } +}; + +// AM/PM +class p_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + msg.formatted << ampm(tm_time); + } +}; + + +// 12 hour clock 02:55:02 pm +class r_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + pad_n_join(msg.formatted, to12h(tm_time), tm_time.tm_min, tm_time.tm_sec, ':') << ' ' << ampm(tm_time); + } +}; + +// 24-hour HH:MM time, equivalent to %H:%M +class R_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + pad_n_join(msg.formatted, tm_time.tm_hour, tm_time.tm_min, ':'); + } +}; + +// ISO 8601 time format (HH:MM:SS), equivalent to %H:%M:%S +class T_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { + pad_n_join(msg.formatted, tm_time.tm_hour, tm_time.tm_min, tm_time.tm_sec, ':'); + } +}; + +// ISO 8601 offset from UTC in timezone (+-HH:MM) +class z_formatter SPDLOG_FINAL:public flag_formatter +{ +public: + const std::chrono::seconds cache_refresh = std::chrono::seconds(5); + + z_formatter():_last_update(std::chrono::seconds(0)), _offset_minutes(0) + {} + z_formatter(const z_formatter&) = delete; + z_formatter& operator=(const z_formatter&) = delete; + + void format(details::log_msg& msg, const std::tm& tm_time) override + { +#ifdef _WIN32 + int total_minutes = get_cached_offset(msg, tm_time); +#else + // No need to chache under gcc, + // it is very fast (already stored in tm.tm_gmtoff) + int total_minutes = os::utc_minutes_offset(tm_time); +#endif + bool is_negative = total_minutes < 0; + char sign; + if (is_negative) + { + total_minutes = -total_minutes; + sign = '-'; + } + else + { + sign = '+'; + } + + int h = total_minutes / 60; + int m = total_minutes % 60; + msg.formatted << sign; + pad_n_join(msg.formatted, h, m, ':'); + } +private: + log_clock::time_point _last_update; + int _offset_minutes; + std::mutex _mutex; + + int get_cached_offset(const log_msg& msg, const std::tm& tm_time) + { + using namespace std::chrono; + std::lock_guard l(_mutex); + if (msg.time - _last_update >= cache_refresh) + { + _offset_minutes = os::utc_minutes_offset(tm_time); + _last_update = msg.time; + } + return _offset_minutes; + } +}; + + + +// Thread id +class t_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << msg.thread_id; + } +}; + +// Current pid +class pid_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << details::os::pid(); + } +}; + + +class v_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << fmt::StringRef(msg.raw.data(), msg.raw.size()); + } +}; + +class ch_formatter SPDLOG_FINAL:public flag_formatter +{ +public: + explicit ch_formatter(char ch): _ch(ch) + {} + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << _ch; + } +private: + char _ch; +}; + + +//aggregate user chars to display as is +class aggregate_formatter SPDLOG_FINAL:public flag_formatter +{ +public: + aggregate_formatter() + {} + void add_ch(char ch) + { + _str += ch; + } + void format(details::log_msg& msg, const std::tm&) override + { + msg.formatted << _str; + } +private: + std::string _str; +}; + +// Full info formatter +// pattern: [%Y-%m-%d %H:%M:%S.%e] [%n] [%l] %v +class full_formatter SPDLOG_FINAL:public flag_formatter +{ + void format(details::log_msg& msg, const std::tm& tm_time) override + { +#ifndef SPDLOG_NO_DATETIME + auto duration = msg.time.time_since_epoch(); + auto millis = std::chrono::duration_cast(duration).count() % 1000; + + /* Slower version(while still very fast - about 3.2 million lines/sec under 10 threads), + msg.formatted.write("[{:d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}.{:03d}] [{}] [{}] {} ", + tm_time.tm_year + 1900, + tm_time.tm_mon + 1, + tm_time.tm_mday, + tm_time.tm_hour, + tm_time.tm_min, + tm_time.tm_sec, + static_cast(millis), + msg.logger_name, + level::to_str(msg.level), + msg.raw.str());*/ + + + // Faster (albeit uglier) way to format the line (5.6 million lines/sec under 10 threads) + msg.formatted << '[' << static_cast(tm_time.tm_year + 1900) << '-' + << fmt::pad(static_cast(tm_time.tm_mon + 1), 2, '0') << '-' + << fmt::pad(static_cast(tm_time.tm_mday), 2, '0') << ' ' + << fmt::pad(static_cast(tm_time.tm_hour), 2, '0') << ':' + << fmt::pad(static_cast(tm_time.tm_min), 2, '0') << ':' + << fmt::pad(static_cast(tm_time.tm_sec), 2, '0') << '.' + << fmt::pad(static_cast(millis), 3, '0') << "] "; + + //no datetime needed +#else + (void)tm_time; +#endif + +#ifndef SPDLOG_NO_NAME + msg.formatted << '[' << *msg.logger_name << "] "; +#endif + + msg.formatted << '[' << level::to_str(msg.level) << "] "; + msg.formatted << fmt::StringRef(msg.raw.data(), msg.raw.size()); + } +}; + + + +} +} +/////////////////////////////////////////////////////////////////////////////// +// pattern_formatter inline impl +/////////////////////////////////////////////////////////////////////////////// +inline spdlog::pattern_formatter::pattern_formatter(const std::string& pattern, pattern_time_type pattern_time) + : _pattern_time(pattern_time) +{ + compile_pattern(pattern); +} + +inline void spdlog::pattern_formatter::compile_pattern(const std::string& pattern) +{ + auto end = pattern.end(); + std::unique_ptr user_chars; + for (auto it = pattern.begin(); it != end; ++it) + { + if (*it == '%') + { + if (user_chars) //append user chars found so far + _formatters.push_back(std::move(user_chars)); + + if (++it != end) + handle_flag(*it); + else + break; + } + else // chars not following the % sign should be displayed as is + { + if (!user_chars) + user_chars = std::unique_ptr(new details::aggregate_formatter()); + user_chars->add_ch(*it); + } + } + if (user_chars) //append raw chars found so far + { + _formatters.push_back(std::move(user_chars)); + } + +} +inline void spdlog::pattern_formatter::handle_flag(char flag) +{ + switch (flag) + { + // logger name + case 'n': + _formatters.push_back(std::unique_ptr(new details::name_formatter())); + break; + + case 'l': + _formatters.push_back(std::unique_ptr(new details::level_formatter())); + break; + + case 'L': + _formatters.push_back(std::unique_ptr(new details::short_level_formatter())); + break; + + case('t'): + _formatters.push_back(std::unique_ptr(new details::t_formatter())); + break; + + case('v'): + _formatters.push_back(std::unique_ptr(new details::v_formatter())); + break; + + case('a'): + _formatters.push_back(std::unique_ptr(new details::a_formatter())); + break; + + case('A'): + _formatters.push_back(std::unique_ptr(new details::A_formatter())); + break; + + case('b'): + case('h'): + _formatters.push_back(std::unique_ptr(new details::b_formatter())); + break; + + case('B'): + _formatters.push_back(std::unique_ptr(new details::B_formatter())); + break; + case('c'): + _formatters.push_back(std::unique_ptr(new details::c_formatter())); + break; + + case('C'): + _formatters.push_back(std::unique_ptr(new details::C_formatter())); + break; + + case('Y'): + _formatters.push_back(std::unique_ptr(new details::Y_formatter())); + break; + + case('D'): + case('x'): + + _formatters.push_back(std::unique_ptr(new details::D_formatter())); + break; + + case('m'): + _formatters.push_back(std::unique_ptr(new details::m_formatter())); + break; + + case('d'): + _formatters.push_back(std::unique_ptr(new details::d_formatter())); + break; + + case('H'): + _formatters.push_back(std::unique_ptr(new details::H_formatter())); + break; + + case('I'): + _formatters.push_back(std::unique_ptr(new details::I_formatter())); + break; + + case('M'): + _formatters.push_back(std::unique_ptr(new details::M_formatter())); + break; + + case('S'): + _formatters.push_back(std::unique_ptr(new details::S_formatter())); + break; + + case('e'): + _formatters.push_back(std::unique_ptr(new details::e_formatter())); + break; + + case('f'): + _formatters.push_back(std::unique_ptr(new details::f_formatter())); + break; + case('F'): + _formatters.push_back(std::unique_ptr(new details::F_formatter())); + break; + + case('p'): + _formatters.push_back(std::unique_ptr(new details::p_formatter())); + break; + + case('r'): + _formatters.push_back(std::unique_ptr(new details::r_formatter())); + break; + + case('R'): + _formatters.push_back(std::unique_ptr(new details::R_formatter())); + break; + + case('T'): + case('X'): + _formatters.push_back(std::unique_ptr(new details::T_formatter())); + break; + + case('z'): + _formatters.push_back(std::unique_ptr(new details::z_formatter())); + break; + + case ('+'): + _formatters.push_back(std::unique_ptr(new details::full_formatter())); + break; + + case ('P'): + _formatters.push_back(std::unique_ptr(new details::pid_formatter())); + break; + +#if defined(SPDLOG_ENABLE_MESSAGE_COUNTER) + case ('i'): + _formatters.push_back(std::unique_ptr(new details::i_formatter())); + break; +#endif + + default: //Unknown flag appears as is + _formatters.push_back(std::unique_ptr(new details::ch_formatter('%'))); + _formatters.push_back(std::unique_ptr(new details::ch_formatter(flag))); + break; + } +} + +inline std::tm spdlog::pattern_formatter::get_time(details::log_msg& msg) +{ + if (_pattern_time == pattern_time_type::local) + return details::os::localtime(log_clock::to_time_t(msg.time)); + else + return details::os::gmtime(log_clock::to_time_t(msg.time)); +} + +inline void spdlog::pattern_formatter::format(details::log_msg& msg) +{ + +#ifndef SPDLOG_NO_DATETIME + auto tm_time = get_time(msg); +#else + std::tm tm_time; +#endif + for (auto &f : _formatters) + { + f->format(msg, tm_time); + } + //write eol + msg.formatted.write(details::os::eol, details::os::eol_size); +} diff --git a/src/dionysus/wasserstein/spdlog/details/registry.h b/src/dionysus/wasserstein/spdlog/details/registry.h new file mode 100755 index 0000000..1064488 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/registry.h @@ -0,0 +1,214 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// Loggers registy of unique name->logger pointer +// An attempt to create a logger with an already existing name will be ignored +// If user requests a non existing logger, nullptr will be returned +// This class is thread safe + +#include "../details/null_mutex.h" +#include "../logger.h" +#include "../async_logger.h" +#include "../common.h" + +#include +#include +#include +#include +#include +#include + +namespace spdlog +{ +namespace details +{ +template class registry_t +{ +public: + + void register_logger(std::shared_ptr logger) + { + std::lock_guard lock(_mutex); + auto logger_name = logger->name(); + throw_if_exists(logger_name); + _loggers[logger_name] = logger; + } + + + std::shared_ptr get(const std::string& logger_name) + { + std::lock_guard lock(_mutex); + auto found = _loggers.find(logger_name); + return found == _loggers.end() ? nullptr : found->second; + } + + template + std::shared_ptr create(const std::string& logger_name, const It& sinks_begin, const It& sinks_end) + { + std::lock_guard lock(_mutex); + throw_if_exists(logger_name); + std::shared_ptr new_logger; + if (_async_mode) + new_logger = std::make_shared(logger_name, sinks_begin, sinks_end, _async_q_size, _overflow_policy, _worker_warmup_cb, _flush_interval_ms, _worker_teardown_cb); + else + new_logger = std::make_shared(logger_name, sinks_begin, sinks_end); + + if (_formatter) + new_logger->set_formatter(_formatter); + + if (_err_handler) + new_logger->set_error_handler(_err_handler); + + new_logger->set_level(_level); + + + //Add to registry + _loggers[logger_name] = new_logger; + return new_logger; + } + + template + std::shared_ptr create_async(const std::string& logger_name, size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb, const It& sinks_begin, const It& sinks_end) + { + std::lock_guard lock(_mutex); + throw_if_exists(logger_name); + auto new_logger = std::make_shared(logger_name, sinks_begin, sinks_end, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb); + + if (_formatter) + new_logger->set_formatter(_formatter); + + if (_err_handler) + new_logger->set_error_handler(_err_handler); + + new_logger->set_level(_level); + + //Add to registry + _loggers[logger_name] = new_logger; + return new_logger; + } + + void apply_all(std::function)> fun) + { + std::lock_guard lock(_mutex); + for (auto &l : _loggers) + fun(l.second); + } + + void drop(const std::string& logger_name) + { + std::lock_guard lock(_mutex); + _loggers.erase(logger_name); + } + + void drop_all() + { + std::lock_guard lock(_mutex); + _loggers.clear(); + } + std::shared_ptr create(const std::string& logger_name, sinks_init_list sinks) + { + return create(logger_name, sinks.begin(), sinks.end()); + } + + std::shared_ptr create(const std::string& logger_name, sink_ptr sink) + { + return create(logger_name, { sink }); + } + + std::shared_ptr create_async(const std::string& logger_name, size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb, sinks_init_list sinks) + { + return create_async(logger_name, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb, sinks.begin(), sinks.end()); + } + + std::shared_ptr create_async(const std::string& logger_name, size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb, sink_ptr sink) + { + return create_async(logger_name, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb, { sink }); + } + + void formatter(formatter_ptr f) + { + std::lock_guard lock(_mutex); + _formatter = f; + for (auto& l : _loggers) + l.second->set_formatter(_formatter); + } + + void set_pattern(const std::string& pattern) + { + std::lock_guard lock(_mutex); + _formatter = std::make_shared(pattern); + for (auto& l : _loggers) + l.second->set_formatter(_formatter); + } + + void set_level(level::level_enum log_level) + { + std::lock_guard lock(_mutex); + for (auto& l : _loggers) + l.second->set_level(log_level); + _level = log_level; + } + + void set_error_handler(log_err_handler handler) + { + for (auto& l : _loggers) + l.second->set_error_handler(handler); + _err_handler = handler; + } + + void set_async_mode(size_t q_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb) + { + std::lock_guard lock(_mutex); + _async_mode = true; + _async_q_size = q_size; + _overflow_policy = overflow_policy; + _worker_warmup_cb = worker_warmup_cb; + _flush_interval_ms = flush_interval_ms; + _worker_teardown_cb = worker_teardown_cb; + } + + void set_sync_mode() + { + std::lock_guard lock(_mutex); + _async_mode = false; + } + + static registry_t& instance() + { + static registry_t s_instance; + return s_instance; + } + +private: + registry_t() {} + registry_t(const registry_t&) = delete; + registry_t& operator=(const registry_t&) = delete; + + void throw_if_exists(const std::string &logger_name) + { + if (_loggers.find(logger_name) != _loggers.end()) + throw spdlog_ex("logger with name '" + logger_name + "' already exists"); + } + Mutex _mutex; + std::unordered_map > _loggers; + formatter_ptr _formatter; + level::level_enum _level = level::info; + log_err_handler _err_handler; + bool _async_mode = false; + size_t _async_q_size = 0; + async_overflow_policy _overflow_policy = async_overflow_policy::block_retry; + std::function _worker_warmup_cb = nullptr; + std::chrono::milliseconds _flush_interval_ms; + std::function _worker_teardown_cb = nullptr; +}; +#ifdef SPDLOG_NO_REGISTRY_MUTEX +typedef registry_t registry; +#else +typedef registry_t registry; +#endif +} +} diff --git a/src/dionysus/wasserstein/spdlog/details/spdlog_impl.h b/src/dionysus/wasserstein/spdlog/details/spdlog_impl.h new file mode 100755 index 0000000..7b9151f --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/details/spdlog_impl.h @@ -0,0 +1,263 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// +// Global registry functions +// +#include "../spdlog.h" +#include "../details/registry.h" +#include "../sinks/file_sinks.h" +#include "../sinks/stdout_sinks.h" +#ifdef SPDLOG_ENABLE_SYSLOG +#include "../sinks/syslog_sink.h" +#endif + +#ifdef _WIN32 +#include "../sinks/wincolor_sink.h" +#else +#include "../sinks/ansicolor_sink.h" +#endif + + +#ifdef __ANDROID__ +#include "../sinks/android_sink.h" +#endif + +#include +#include +#include +#include + +inline void spdlog::register_logger(std::shared_ptr logger) +{ + return details::registry::instance().register_logger(logger); +} + +inline std::shared_ptr spdlog::get(const std::string& name) +{ + return details::registry::instance().get(name); +} + +inline void spdlog::drop(const std::string &name) +{ + details::registry::instance().drop(name); +} + +// Create multi/single threaded simple file logger +inline std::shared_ptr spdlog::basic_logger_mt(const std::string& logger_name, const filename_t& filename, bool truncate) +{ + return create(logger_name, filename, truncate); +} + +inline std::shared_ptr spdlog::basic_logger_st(const std::string& logger_name, const filename_t& filename, bool truncate) +{ + return create(logger_name, filename, truncate); +} + +// Create multi/single threaded rotating file logger +inline std::shared_ptr spdlog::rotating_logger_mt(const std::string& logger_name, const filename_t& filename, size_t max_file_size, size_t max_files) +{ + return create(logger_name, filename, max_file_size, max_files); +} + +inline std::shared_ptr spdlog::rotating_logger_st(const std::string& logger_name, const filename_t& filename, size_t max_file_size, size_t max_files) +{ + return create(logger_name, filename, max_file_size, max_files); +} + +// Create file logger which creates new file at midnight): +inline std::shared_ptr spdlog::daily_logger_mt(const std::string& logger_name, const filename_t& filename, int hour, int minute) +{ + return create(logger_name, filename, hour, minute); +} + +inline std::shared_ptr spdlog::daily_logger_st(const std::string& logger_name, const filename_t& filename, int hour, int minute) +{ + return create(logger_name, filename, hour, minute); +} + + +// +// stdout/stderr loggers +// +inline std::shared_ptr spdlog::stdout_logger_mt(const std::string& logger_name) +{ + return spdlog::details::registry::instance().create(logger_name, spdlog::sinks::stdout_sink_mt::instance()); +} + +inline std::shared_ptr spdlog::stdout_logger_st(const std::string& logger_name) +{ + return spdlog::details::registry::instance().create(logger_name, spdlog::sinks::stdout_sink_st::instance()); +} + +inline std::shared_ptr spdlog::stderr_logger_mt(const std::string& logger_name) +{ + return spdlog::details::registry::instance().create(logger_name, spdlog::sinks::stderr_sink_mt::instance()); +} + +inline std::shared_ptr spdlog::stderr_logger_st(const std::string& logger_name) +{ + return spdlog::details::registry::instance().create(logger_name, spdlog::sinks::stderr_sink_st::instance()); +} + +// +// stdout/stderr color loggers +// +#ifdef _WIN32 +inline std::shared_ptr spdlog::stdout_color_mt(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + +inline std::shared_ptr spdlog::stdout_color_st(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + +inline std::shared_ptr spdlog::stderr_color_mt(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + + +inline std::shared_ptr spdlog::stderr_color_st(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + +#else //ansi terminal colors + +inline std::shared_ptr spdlog::stdout_color_mt(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + +inline std::shared_ptr spdlog::stdout_color_st(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + +inline std::shared_ptr spdlog::stderr_color_mt(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} + +inline std::shared_ptr spdlog::stderr_color_st(const std::string& logger_name) +{ + auto sink = std::make_shared(); + return spdlog::details::registry::instance().create(logger_name, sink); +} +#endif + +#ifdef SPDLOG_ENABLE_SYSLOG +// Create syslog logger +inline std::shared_ptr spdlog::syslog_logger(const std::string& logger_name, const std::string& syslog_ident, int syslog_option) +{ + return create(logger_name, syslog_ident, syslog_option); +} +#endif + +#ifdef __ANDROID__ +inline std::shared_ptr spdlog::android_logger(const std::string& logger_name, const std::string& tag) +{ + return create(logger_name, tag); +} +#endif + +// Create and register a logger a single sink +inline std::shared_ptr spdlog::create(const std::string& logger_name, const spdlog::sink_ptr& sink) +{ + return details::registry::instance().create(logger_name, sink); +} + +//Create logger with multiple sinks + +inline std::shared_ptr spdlog::create(const std::string& logger_name, spdlog::sinks_init_list sinks) +{ + return details::registry::instance().create(logger_name, sinks); +} + + +template +inline std::shared_ptr spdlog::create(const std::string& logger_name, Args... args) +{ + sink_ptr sink = std::make_shared(args...); + return details::registry::instance().create(logger_name, { sink }); +} + + +template +inline std::shared_ptr spdlog::create(const std::string& logger_name, const It& sinks_begin, const It& sinks_end) +{ + return details::registry::instance().create(logger_name, sinks_begin, sinks_end); +} + +// Create and register an async logger with a single sink +inline std::shared_ptr spdlog::create_async(const std::string& logger_name, const sink_ptr& sink, size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb) +{ + return details::registry::instance().create_async(logger_name, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb, sink); +} + +// Create and register an async logger with multiple sinks +inline std::shared_ptr spdlog::create_async(const std::string& logger_name, sinks_init_list sinks, size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb ) +{ + return details::registry::instance().create_async(logger_name, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb, sinks); +} + +template +inline std::shared_ptr spdlog::create_async(const std::string& logger_name, const It& sinks_begin, const It& sinks_end, size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb) +{ + return details::registry::instance().create_async(logger_name, queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb, sinks_begin, sinks_end); +} + +inline void spdlog::set_formatter(spdlog::formatter_ptr f) +{ + details::registry::instance().formatter(f); +} + +inline void spdlog::set_pattern(const std::string& format_string) +{ + return details::registry::instance().set_pattern(format_string); +} + +inline void spdlog::set_level(level::level_enum log_level) +{ + return details::registry::instance().set_level(log_level); +} + +inline void spdlog::set_error_handler(log_err_handler handler) +{ + return details::registry::instance().set_error_handler(handler); +} + + +inline void spdlog::set_async_mode(size_t queue_size, const async_overflow_policy overflow_policy, const std::function& worker_warmup_cb, const std::chrono::milliseconds& flush_interval_ms, const std::function& worker_teardown_cb) +{ + details::registry::instance().set_async_mode(queue_size, overflow_policy, worker_warmup_cb, flush_interval_ms, worker_teardown_cb); +} + +inline void spdlog::set_sync_mode() +{ + details::registry::instance().set_sync_mode(); +} + +inline void spdlog::apply_all(std::function)> fun) +{ + details::registry::instance().apply_all(fun); +} + +inline void spdlog::drop_all() +{ + details::registry::instance().drop_all(); +} diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/format.cc b/src/dionysus/wasserstein/spdlog/fmt/bundled/format.cc new file mode 100755 index 0000000..2bd774e --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/format.cc @@ -0,0 +1,940 @@ +/* + Formatting library for C++ + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "format.h" + +#include + +#include +#include +#include +#include +#include +#include // for std::ptrdiff_t + +#if defined(_WIN32) && defined(__MINGW32__) +# include +#endif + +#if FMT_USE_WINDOWS_H +# if defined(NOMINMAX) || defined(FMT_WIN_MINMAX) +# include +# else +# define NOMINMAX +# include +# undef NOMINMAX +# endif +#endif + +using fmt::internal::Arg; + +#if FMT_EXCEPTIONS +# define FMT_TRY try +# define FMT_CATCH(x) catch (x) +#else +# define FMT_TRY if (true) +# define FMT_CATCH(x) if (false) +#endif + +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable: 4127) // conditional expression is constant +# pragma warning(disable: 4702) // unreachable code +// Disable deprecation warning for strerror. The latter is not called but +// MSVC fails to detect it. +# pragma warning(disable: 4996) +#endif + +// Dummy implementations of strerror_r and strerror_s called if corresponding +// system functions are not available. +static inline fmt::internal::Null<> strerror_r(int, char *, ...) { + return fmt::internal::Null<>(); +} +static inline fmt::internal::Null<> strerror_s(char *, std::size_t, ...) { + return fmt::internal::Null<>(); +} + +namespace fmt { + +FMT_FUNC internal::RuntimeError::~RuntimeError() throw() {} +FMT_FUNC FormatError::~FormatError() throw() {} +FMT_FUNC SystemError::~SystemError() throw() {} + +namespace { + +#ifndef _MSC_VER +# define FMT_SNPRINTF snprintf +#else // _MSC_VER +inline int fmt_snprintf(char *buffer, size_t size, const char *format, ...) { + va_list args; + va_start(args, format); + int result = vsnprintf_s(buffer, size, _TRUNCATE, format, args); + va_end(args); + return result; +} +# define FMT_SNPRINTF fmt_snprintf +#endif // _MSC_VER + +#if defined(_WIN32) && defined(__MINGW32__) && !defined(__NO_ISOCEXT) +# define FMT_SWPRINTF snwprintf +#else +# define FMT_SWPRINTF swprintf +#endif // defined(_WIN32) && defined(__MINGW32__) && !defined(__NO_ISOCEXT) + +// Checks if a value fits in int - used to avoid warnings about comparing +// signed and unsigned integers. +template +struct IntChecker { + template + static bool fits_in_int(T value) { + unsigned max = INT_MAX; + return value <= max; + } + static bool fits_in_int(bool) { return true; } +}; + +template <> +struct IntChecker { + template + static bool fits_in_int(T value) { + return value >= INT_MIN && value <= INT_MAX; + } + static bool fits_in_int(int) { return true; } +}; + +const char RESET_COLOR[] = "\x1b[0m"; + +typedef void (*FormatFunc)(Writer &, int, StringRef); + +// Portable thread-safe version of strerror. +// Sets buffer to point to a string describing the error code. +// This can be either a pointer to a string stored in buffer, +// or a pointer to some static immutable string. +// Returns one of the following values: +// 0 - success +// ERANGE - buffer is not large enough to store the error message +// other - failure +// Buffer should be at least of size 1. +int safe_strerror( + int error_code, char *&buffer, std::size_t buffer_size) FMT_NOEXCEPT { + FMT_ASSERT(buffer != 0 && buffer_size != 0, "invalid buffer"); + + class StrError { + private: + int error_code_; + char *&buffer_; + std::size_t buffer_size_; + + // A noop assignment operator to avoid bogus warnings. + void operator=(const StrError &) {} + + // Handle the result of XSI-compliant version of strerror_r. + int handle(int result) { + // glibc versions before 2.13 return result in errno. + return result == -1 ? errno : result; + } + + // Handle the result of GNU-specific version of strerror_r. + int handle(char *message) { + // If the buffer is full then the message is probably truncated. + if (message == buffer_ && strlen(buffer_) == buffer_size_ - 1) + return ERANGE; + buffer_ = message; + return 0; + } + + // Handle the case when strerror_r is not available. + int handle(internal::Null<>) { + return fallback(strerror_s(buffer_, buffer_size_, error_code_)); + } + + // Fallback to strerror_s when strerror_r is not available. + int fallback(int result) { + // If the buffer is full then the message is probably truncated. + return result == 0 && strlen(buffer_) == buffer_size_ - 1 ? + ERANGE : result; + } + + // Fallback to strerror if strerror_r and strerror_s are not available. + int fallback(internal::Null<>) { + errno = 0; + buffer_ = strerror(error_code_); + return errno; + } + + public: + StrError(int err_code, char *&buf, std::size_t buf_size) + : error_code_(err_code), buffer_(buf), buffer_size_(buf_size) {} + + int run() { + strerror_r(0, 0, ""); // Suppress a warning about unused strerror_r. + return handle(strerror_r(error_code_, buffer_, buffer_size_)); + } + }; + return StrError(error_code, buffer, buffer_size).run(); +} + +void format_error_code(Writer &out, int error_code, + StringRef message) FMT_NOEXCEPT { + // Report error code making sure that the output fits into + // INLINE_BUFFER_SIZE to avoid dynamic memory allocation and potential + // bad_alloc. + out.clear(); + static const char SEP[] = ": "; + static const char ERROR_STR[] = "error "; + // Subtract 2 to account for terminating null characters in SEP and ERROR_STR. + std::size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2; + typedef internal::IntTraits::MainType MainType; + MainType abs_value = static_cast(error_code); + if (internal::is_negative(error_code)) { + abs_value = 0 - abs_value; + ++error_code_size; + } + error_code_size += internal::count_digits(abs_value); + if (message.size() <= internal::INLINE_BUFFER_SIZE - error_code_size) + out << message << SEP; + out << ERROR_STR << error_code; + assert(out.size() <= internal::INLINE_BUFFER_SIZE); +} + +void report_error(FormatFunc func, int error_code, + StringRef message) FMT_NOEXCEPT { + MemoryWriter full_message; + func(full_message, error_code, message); + // Use Writer::data instead of Writer::c_str to avoid potential memory + // allocation. + std::fwrite(full_message.data(), full_message.size(), 1, stderr); + std::fputc('\n', stderr); +} + +// IsZeroInt::visit(arg) returns true iff arg is a zero integer. +class IsZeroInt : public ArgVisitor { + public: + template + bool visit_any_int(T value) { return value == 0; } +}; + +// Checks if an argument is a valid printf width specifier and sets +// left alignment if it is negative. +class WidthHandler : public ArgVisitor { + private: + FormatSpec &spec_; + + FMT_DISALLOW_COPY_AND_ASSIGN(WidthHandler); + + public: + explicit WidthHandler(FormatSpec &spec) : spec_(spec) {} + + void report_unhandled_arg() { + FMT_THROW(FormatError("width is not integer")); + } + + template + unsigned visit_any_int(T value) { + typedef typename internal::IntTraits::MainType UnsignedType; + UnsignedType width = static_cast(value); + if (internal::is_negative(value)) { + spec_.align_ = ALIGN_LEFT; + width = 0 - width; + } + if (width > INT_MAX) + FMT_THROW(FormatError("number is too big")); + return static_cast(width); + } +}; + +class PrecisionHandler : public ArgVisitor { + public: + void report_unhandled_arg() { + FMT_THROW(FormatError("precision is not integer")); + } + + template + int visit_any_int(T value) { + if (!IntChecker::is_signed>::fits_in_int(value)) + FMT_THROW(FormatError("number is too big")); + return static_cast(value); + } +}; + +template +struct is_same { + enum { value = 0 }; +}; + +template +struct is_same { + enum { value = 1 }; +}; + +// An argument visitor that converts an integer argument to T for printf, +// if T is an integral type. If T is void, the argument is converted to +// corresponding signed or unsigned type depending on the type specifier: +// 'd' and 'i' - signed, other - unsigned) +template +class ArgConverter : public ArgVisitor, void> { + private: + internal::Arg &arg_; + wchar_t type_; + + FMT_DISALLOW_COPY_AND_ASSIGN(ArgConverter); + + public: + ArgConverter(internal::Arg &arg, wchar_t type) + : arg_(arg), type_(type) {} + + void visit_bool(bool value) { + if (type_ != 's') + visit_any_int(value); + } + + template + void visit_any_int(U value) { + bool is_signed = type_ == 'd' || type_ == 'i'; + using internal::Arg; + typedef typename internal::Conditional< + is_same::value, U, T>::type TargetType; + if (sizeof(TargetType) <= sizeof(int)) { + // Extra casts are used to silence warnings. + if (is_signed) { + arg_.type = Arg::INT; + arg_.int_value = static_cast(static_cast(value)); + } else { + arg_.type = Arg::UINT; + typedef typename internal::MakeUnsigned::Type Unsigned; + arg_.uint_value = static_cast(static_cast(value)); + } + } else { + if (is_signed) { + arg_.type = Arg::LONG_LONG; + // glibc's printf doesn't sign extend arguments of smaller types: + // std::printf("%lld", -42); // prints "4294967254" + // but we don't have to do the same because it's a UB. + arg_.long_long_value = static_cast(value); + } else { + arg_.type = Arg::ULONG_LONG; + arg_.ulong_long_value = + static_cast::Type>(value); + } + } + } +}; + +// Converts an integer argument to char for printf. +class CharConverter : public ArgVisitor { + private: + internal::Arg &arg_; + + FMT_DISALLOW_COPY_AND_ASSIGN(CharConverter); + + public: + explicit CharConverter(internal::Arg &arg) : arg_(arg) {} + + template + void visit_any_int(T value) { + arg_.type = internal::Arg::CHAR; + arg_.int_value = static_cast(value); + } +}; +} // namespace + +namespace internal { + +template +class PrintfArgFormatter : + public ArgFormatterBase, Char> { + + void write_null_pointer() { + this->spec().type_ = 0; + this->write("(nil)"); + } + + typedef ArgFormatterBase, Char> Base; + + public: + PrintfArgFormatter(BasicWriter &w, FormatSpec &s) + : ArgFormatterBase, Char>(w, s) {} + + void visit_bool(bool value) { + FormatSpec &fmt_spec = this->spec(); + if (fmt_spec.type_ != 's') + return this->visit_any_int(value); + fmt_spec.type_ = 0; + this->write(value); + } + + void visit_char(int value) { + const FormatSpec &fmt_spec = this->spec(); + BasicWriter &w = this->writer(); + if (fmt_spec.type_ && fmt_spec.type_ != 'c') + w.write_int(value, fmt_spec); + typedef typename BasicWriter::CharPtr CharPtr; + CharPtr out = CharPtr(); + if (fmt_spec.width_ > 1) { + Char fill = ' '; + out = w.grow_buffer(fmt_spec.width_); + if (fmt_spec.align_ != ALIGN_LEFT) { + std::fill_n(out, fmt_spec.width_ - 1, fill); + out += fmt_spec.width_ - 1; + } else { + std::fill_n(out + 1, fmt_spec.width_ - 1, fill); + } + } else { + out = w.grow_buffer(1); + } + *out = static_cast(value); + } + + void visit_cstring(const char *value) { + if (value) + Base::visit_cstring(value); + else if (this->spec().type_ == 'p') + write_null_pointer(); + else + this->write("(null)"); + } + + void visit_pointer(const void *value) { + if (value) + return Base::visit_pointer(value); + this->spec().type_ = 0; + write_null_pointer(); + } + + void visit_custom(Arg::CustomValue c) { + BasicFormatter formatter(ArgList(), this->writer()); + const Char format_str[] = {'}', 0}; + const Char *format = format_str; + c.format(&formatter, c.value, &format); + } +}; +} // namespace internal +} // namespace fmt + +FMT_FUNC void fmt::SystemError::init( + int err_code, CStringRef format_str, ArgList args) { + error_code_ = err_code; + MemoryWriter w; + internal::format_system_error(w, err_code, format(format_str, args)); + std::runtime_error &base = *this; + base = std::runtime_error(w.str()); +} + +template +int fmt::internal::CharTraits::format_float( + char *buffer, std::size_t size, const char *format, + unsigned width, int precision, T value) { + if (width == 0) { + return precision < 0 ? + FMT_SNPRINTF(buffer, size, format, value) : + FMT_SNPRINTF(buffer, size, format, precision, value); + } + return precision < 0 ? + FMT_SNPRINTF(buffer, size, format, width, value) : + FMT_SNPRINTF(buffer, size, format, width, precision, value); +} + +template +int fmt::internal::CharTraits::format_float( + wchar_t *buffer, std::size_t size, const wchar_t *format, + unsigned width, int precision, T value) { + if (width == 0) { + return precision < 0 ? + FMT_SWPRINTF(buffer, size, format, value) : + FMT_SWPRINTF(buffer, size, format, precision, value); + } + return precision < 0 ? + FMT_SWPRINTF(buffer, size, format, width, value) : + FMT_SWPRINTF(buffer, size, format, width, precision, value); +} + +template +const char fmt::internal::BasicData::DIGITS[] = + "0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"; + +#define FMT_POWERS_OF_10(factor) \ + factor * 10, \ + factor * 100, \ + factor * 1000, \ + factor * 10000, \ + factor * 100000, \ + factor * 1000000, \ + factor * 10000000, \ + factor * 100000000, \ + factor * 1000000000 + +template +const uint32_t fmt::internal::BasicData::POWERS_OF_10_32[] = { + 0, FMT_POWERS_OF_10(1) +}; + +template +const uint64_t fmt::internal::BasicData::POWERS_OF_10_64[] = { + 0, + FMT_POWERS_OF_10(1), + FMT_POWERS_OF_10(fmt::ULongLong(1000000000)), + // Multiply several constants instead of using a single long long constant + // to avoid warnings about C++98 not supporting long long. + fmt::ULongLong(1000000000) * fmt::ULongLong(1000000000) * 10 +}; + +FMT_FUNC void fmt::internal::report_unknown_type(char code, const char *type) { + (void)type; + if (std::isprint(static_cast(code))) { + FMT_THROW(fmt::FormatError( + fmt::format("unknown format code '{}' for {}", code, type))); + } + FMT_THROW(fmt::FormatError( + fmt::format("unknown format code '\\x{:02x}' for {}", + static_cast(code), type))); +} + +#if FMT_USE_WINDOWS_H + +FMT_FUNC fmt::internal::UTF8ToUTF16::UTF8ToUTF16(fmt::StringRef s) { + static const char ERROR_MSG[] = "cannot convert string from UTF-8 to UTF-16"; + if (s.size() > INT_MAX) + FMT_THROW(WindowsError(ERROR_INVALID_PARAMETER, ERROR_MSG)); + int s_size = static_cast(s.size()); + int length = MultiByteToWideChar( + CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), s_size, 0, 0); + if (length == 0) + FMT_THROW(WindowsError(GetLastError(), ERROR_MSG)); + buffer_.resize(length + 1); + length = MultiByteToWideChar( + CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), s_size, &buffer_[0], length); + if (length == 0) + FMT_THROW(WindowsError(GetLastError(), ERROR_MSG)); + buffer_[length] = 0; +} + +FMT_FUNC fmt::internal::UTF16ToUTF8::UTF16ToUTF8(fmt::WStringRef s) { + if (int error_code = convert(s)) { + FMT_THROW(WindowsError(error_code, + "cannot convert string from UTF-16 to UTF-8")); + } +} + +FMT_FUNC int fmt::internal::UTF16ToUTF8::convert(fmt::WStringRef s) { + if (s.size() > INT_MAX) + return ERROR_INVALID_PARAMETER; + int s_size = static_cast(s.size()); + int length = WideCharToMultiByte(CP_UTF8, 0, s.data(), s_size, 0, 0, 0, 0); + if (length == 0) + return GetLastError(); + buffer_.resize(length + 1); + length = WideCharToMultiByte( + CP_UTF8, 0, s.data(), s_size, &buffer_[0], length, 0, 0); + if (length == 0) + return GetLastError(); + buffer_[length] = 0; + return 0; +} + +FMT_FUNC void fmt::WindowsError::init( + int err_code, CStringRef format_str, ArgList args) { + error_code_ = err_code; + MemoryWriter w; + internal::format_windows_error(w, err_code, format(format_str, args)); + std::runtime_error &base = *this; + base = std::runtime_error(w.str()); +} + +FMT_FUNC void fmt::internal::format_windows_error( + fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT { + FMT_TRY { + MemoryBuffer buffer; + buffer.resize(INLINE_BUFFER_SIZE); + for (;;) { + wchar_t *system_message = &buffer[0]; + int result = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + 0, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + system_message, static_cast(buffer.size()), 0); + if (result != 0) { + UTF16ToUTF8 utf8_message; + if (utf8_message.convert(system_message) == ERROR_SUCCESS) { + out << message << ": " << utf8_message; + return; + } + break; + } + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) + break; // Can't get error message, report error code instead. + buffer.resize(buffer.size() * 2); + } + } FMT_CATCH(...) {} + fmt::format_error_code(out, error_code, message); // 'fmt::' is for bcc32. +} + +#endif // FMT_USE_WINDOWS_H + +FMT_FUNC void fmt::internal::format_system_error( + fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT { + FMT_TRY { + MemoryBuffer buffer; + buffer.resize(INLINE_BUFFER_SIZE); + for (;;) { + char *system_message = &buffer[0]; + int result = safe_strerror(error_code, system_message, buffer.size()); + if (result == 0) { + out << message << ": " << system_message; + return; + } + if (result != ERANGE) + break; // Can't get error message, report error code instead. + buffer.resize(buffer.size() * 2); + } + } FMT_CATCH(...) {} + fmt::format_error_code(out, error_code, message); // 'fmt::' is for bcc32. +} + +template +void fmt::internal::ArgMap::init(const ArgList &args) { + if (!map_.empty()) + return; + typedef internal::NamedArg NamedArg; + const NamedArg *named_arg = 0; + bool use_values = + args.type(ArgList::MAX_PACKED_ARGS - 1) == internal::Arg::NONE; + if (use_values) { + for (unsigned i = 0;/*nothing*/; ++i) { + internal::Arg::Type arg_type = args.type(i); + switch (arg_type) { + case internal::Arg::NONE: + return; + case internal::Arg::NAMED_ARG: + named_arg = static_cast(args.values_[i].pointer); + map_.push_back(Pair(named_arg->name, *named_arg)); + break; + default: + /*nothing*/; + } + } + return; + } + for (unsigned i = 0; i != ArgList::MAX_PACKED_ARGS; ++i) { + internal::Arg::Type arg_type = args.type(i); + if (arg_type == internal::Arg::NAMED_ARG) { + named_arg = static_cast(args.args_[i].pointer); + map_.push_back(Pair(named_arg->name, *named_arg)); + } + } + for (unsigned i = ArgList::MAX_PACKED_ARGS;/*nothing*/; ++i) { + switch (args.args_[i].type) { + case internal::Arg::NONE: + return; + case internal::Arg::NAMED_ARG: + named_arg = static_cast(args.args_[i].pointer); + map_.push_back(Pair(named_arg->name, *named_arg)); + break; + default: + /*nothing*/; + } + } +} + +template +void fmt::internal::FixedBuffer::grow(std::size_t) { + FMT_THROW(std::runtime_error("buffer overflow")); +} + +FMT_FUNC Arg fmt::internal::FormatterBase::do_get_arg( + unsigned arg_index, const char *&error) { + Arg arg = args_[arg_index]; + switch (arg.type) { + case Arg::NONE: + error = "argument index out of range"; + break; + case Arg::NAMED_ARG: + arg = *static_cast(arg.pointer); + break; + default: + /*nothing*/; + } + return arg; +} + +template +void fmt::internal::PrintfFormatter::parse_flags( + FormatSpec &spec, const Char *&s) { + for (;;) { + switch (*s++) { + case '-': + spec.align_ = ALIGN_LEFT; + break; + case '+': + spec.flags_ |= SIGN_FLAG | PLUS_FLAG; + break; + case '0': + spec.fill_ = '0'; + break; + case ' ': + spec.flags_ |= SIGN_FLAG; + break; + case '#': + spec.flags_ |= HASH_FLAG; + break; + default: + --s; + return; + } + } +} + +template +Arg fmt::internal::PrintfFormatter::get_arg( + const Char *s, unsigned arg_index) { + (void)s; + const char *error = 0; + Arg arg = arg_index == UINT_MAX ? + next_arg(error) : FormatterBase::get_arg(arg_index - 1, error); + if (error) + FMT_THROW(FormatError(!*s ? "invalid format string" : error)); + return arg; +} + +template +unsigned fmt::internal::PrintfFormatter::parse_header( + const Char *&s, FormatSpec &spec) { + unsigned arg_index = UINT_MAX; + Char c = *s; + if (c >= '0' && c <= '9') { + // Parse an argument index (if followed by '$') or a width possibly + // preceded with '0' flag(s). + unsigned value = parse_nonnegative_int(s); + if (*s == '$') { // value is an argument index + ++s; + arg_index = value; + } else { + if (c == '0') + spec.fill_ = '0'; + if (value != 0) { + // Nonzero value means that we parsed width and don't need to + // parse it or flags again, so return now. + spec.width_ = value; + return arg_index; + } + } + } + parse_flags(spec, s); + // Parse width. + if (*s >= '0' && *s <= '9') { + spec.width_ = parse_nonnegative_int(s); + } else if (*s == '*') { + ++s; + spec.width_ = WidthHandler(spec).visit(get_arg(s)); + } + return arg_index; +} + +template +void fmt::internal::PrintfFormatter::format( + BasicWriter &writer, BasicCStringRef format_str) { + const Char *start = format_str.c_str(); + const Char *s = start; + while (*s) { + Char c = *s++; + if (c != '%') continue; + if (*s == c) { + write(writer, start, s); + start = ++s; + continue; + } + write(writer, start, s - 1); + + FormatSpec spec; + spec.align_ = ALIGN_RIGHT; + + // Parse argument index, flags and width. + unsigned arg_index = parse_header(s, spec); + + // Parse precision. + if (*s == '.') { + ++s; + if ('0' <= *s && *s <= '9') { + spec.precision_ = static_cast(parse_nonnegative_int(s)); + } else if (*s == '*') { + ++s; + spec.precision_ = PrecisionHandler().visit(get_arg(s)); + } + } + + Arg arg = get_arg(s, arg_index); + if (spec.flag(HASH_FLAG) && IsZeroInt().visit(arg)) + spec.flags_ &= ~to_unsigned(HASH_FLAG); + if (spec.fill_ == '0') { + if (arg.type <= Arg::LAST_NUMERIC_TYPE) + spec.align_ = ALIGN_NUMERIC; + else + spec.fill_ = ' '; // Ignore '0' flag for non-numeric types. + } + + // Parse length and convert the argument to the required type. + switch (*s++) { + case 'h': + if (*s == 'h') + ArgConverter(arg, *++s).visit(arg); + else + ArgConverter(arg, *s).visit(arg); + break; + case 'l': + if (*s == 'l') + ArgConverter(arg, *++s).visit(arg); + else + ArgConverter(arg, *s).visit(arg); + break; + case 'j': + ArgConverter(arg, *s).visit(arg); + break; + case 'z': + ArgConverter(arg, *s).visit(arg); + break; + case 't': + ArgConverter(arg, *s).visit(arg); + break; + case 'L': + // printf produces garbage when 'L' is omitted for long double, no + // need to do the same. + break; + default: + --s; + ArgConverter(arg, *s).visit(arg); + } + + // Parse type. + if (!*s) + FMT_THROW(FormatError("invalid format string")); + spec.type_ = static_cast(*s++); + if (arg.type <= Arg::LAST_INTEGER_TYPE) { + // Normalize type. + switch (spec.type_) { + case 'i': case 'u': + spec.type_ = 'd'; + break; + case 'c': + // TODO: handle wchar_t + CharConverter(arg).visit(arg); + break; + } + } + + start = s; + + // Format argument. + internal::PrintfArgFormatter(writer, spec).visit(arg); + } + write(writer, start, s); +} + +FMT_FUNC void fmt::report_system_error( + int error_code, fmt::StringRef message) FMT_NOEXCEPT { + // 'fmt::' is for bcc32. + fmt::report_error(internal::format_system_error, error_code, message); +} + +#if FMT_USE_WINDOWS_H +FMT_FUNC void fmt::report_windows_error( + int error_code, fmt::StringRef message) FMT_NOEXCEPT { + // 'fmt::' is for bcc32. + fmt::report_error(internal::format_windows_error, error_code, message); +} +#endif + +FMT_FUNC void fmt::print(std::FILE *f, CStringRef format_str, ArgList args) { + MemoryWriter w; + w.write(format_str, args); + std::fwrite(w.data(), 1, w.size(), f); +} + +FMT_FUNC void fmt::print(CStringRef format_str, ArgList args) { + print(stdout, format_str, args); +} + +FMT_FUNC void fmt::print_colored(Color c, CStringRef format, ArgList args) { + char escape[] = "\x1b[30m"; + escape[3] = static_cast('0' + c); + std::fputs(escape, stdout); + print(format, args); + std::fputs(RESET_COLOR, stdout); +} + +FMT_FUNC int fmt::fprintf(std::FILE *f, CStringRef format, ArgList args) { + MemoryWriter w; + printf(w, format, args); + std::size_t size = w.size(); + return std::fwrite(w.data(), 1, size, f) < size ? -1 : static_cast(size); +} + +#ifndef FMT_HEADER_ONLY + +template struct fmt::internal::BasicData; + +// Explicit instantiations for char. + +template void fmt::internal::FixedBuffer::grow(std::size_t); + +template void fmt::internal::ArgMap::init(const fmt::ArgList &args); + +template void fmt::internal::PrintfFormatter::format( + BasicWriter &writer, CStringRef format); + +template int fmt::internal::CharTraits::format_float( + char *buffer, std::size_t size, const char *format, + unsigned width, int precision, double value); + +template int fmt::internal::CharTraits::format_float( + char *buffer, std::size_t size, const char *format, + unsigned width, int precision, long double value); + +// Explicit instantiations for wchar_t. + +template void fmt::internal::FixedBuffer::grow(std::size_t); + +template void fmt::internal::ArgMap::init(const fmt::ArgList &args); + +template void fmt::internal::PrintfFormatter::format( + BasicWriter &writer, WCStringRef format); + +template int fmt::internal::CharTraits::format_float( + wchar_t *buffer, std::size_t size, const wchar_t *format, + unsigned width, int precision, double value); + +template int fmt::internal::CharTraits::format_float( + wchar_t *buffer, std::size_t size, const wchar_t *format, + unsigned width, int precision, long double value); + +#endif // FMT_HEADER_ONLY + +#ifdef _MSC_VER +# pragma warning(pop) +#endif diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/format.h b/src/dionysus/wasserstein/spdlog/fmt/bundled/format.h new file mode 100755 index 0000000..64c949b --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/format.h @@ -0,0 +1,4501 @@ +/* + Formatting library for C++ + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef FMT_FORMAT_H_ +#define FMT_FORMAT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _SECURE_SCL +# define FMT_SECURE_SCL _SECURE_SCL +#else +# define FMT_SECURE_SCL 0 +#endif + +#if FMT_SECURE_SCL +# include +#endif + +#ifdef _MSC_VER +# define FMT_MSC_VER _MSC_VER +#else +# define FMT_MSC_VER 0 +#endif + +#if FMT_MSC_VER && FMT_MSC_VER <= 1500 +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +typedef __int64 intmax_t; +#else +#include +#endif + +#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) +# ifdef FMT_EXPORT +# define FMT_API __declspec(dllexport) +# elif defined(FMT_SHARED) +# define FMT_API __declspec(dllimport) +# endif +#endif +#ifndef FMT_API +# define FMT_API +#endif + +#ifdef __GNUC__ +# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +# define FMT_GCC_EXTENSION __extension__ +# if FMT_GCC_VERSION >= 406 +# pragma GCC diagnostic push +// Disable the warning about "long long" which is sometimes reported even +// when using __extension__. +# pragma GCC diagnostic ignored "-Wlong-long" +// Disable the warning about declaration shadowing because it affects too +// many valid cases. +# pragma GCC diagnostic ignored "-Wshadow" +// Disable the warning about implicit conversions that may change the sign of +// an integer; silencing it otherwise would require many explicit casts. +# pragma GCC diagnostic ignored "-Wsign-conversion" +# endif +# if __cplusplus >= 201103L || defined __GXX_EXPERIMENTAL_CXX0X__ +# define FMT_HAS_GXX_CXX11 1 +# endif +#else +# define FMT_GCC_EXTENSION +#endif + +#if defined(__INTEL_COMPILER) +# define FMT_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICL) +# define FMT_ICC_VERSION __ICL +#endif + +#if defined(__clang__) && !defined(FMT_ICC_VERSION) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdocumentation-unknown-command" +# pragma clang diagnostic ignored "-Wpadded" +#endif + +#ifdef __GNUC_LIBSTD__ +# define FMT_GNUC_LIBSTD_VERSION (__GNUC_LIBSTD__ * 100 + __GNUC_LIBSTD_MINOR__) +#endif + +#ifdef __has_feature +# define FMT_HAS_FEATURE(x) __has_feature(x) +#else +# define FMT_HAS_FEATURE(x) 0 +#endif + +#ifdef __has_builtin +# define FMT_HAS_BUILTIN(x) __has_builtin(x) +#else +# define FMT_HAS_BUILTIN(x) 0 +#endif + +#ifdef __has_cpp_attribute +# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define FMT_HAS_CPP_ATTRIBUTE(x) 0 +#endif + +#ifndef FMT_USE_VARIADIC_TEMPLATES +// Variadic templates are available in GCC since version 4.4 +// (http://gcc.gnu.org/projects/cxx0x.html) and in Visual C++ +// since version 2013. +# define FMT_USE_VARIADIC_TEMPLATES \ + (FMT_HAS_FEATURE(cxx_variadic_templates) || \ + (FMT_GCC_VERSION >= 404 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1800) +#endif + +#ifndef FMT_USE_RVALUE_REFERENCES +// Don't use rvalue references when compiling with clang and an old libstdc++ +// as the latter doesn't provide std::move. +# if defined(FMT_GNUC_LIBSTD_VERSION) && FMT_GNUC_LIBSTD_VERSION <= 402 +# define FMT_USE_RVALUE_REFERENCES 0 +# else +# define FMT_USE_RVALUE_REFERENCES \ + (FMT_HAS_FEATURE(cxx_rvalue_references) || \ + (FMT_GCC_VERSION >= 403 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1600) +# endif +#endif + +#if FMT_USE_RVALUE_REFERENCES +# include // for std::move +#endif + +// Check if exceptions are disabled. +#if defined(__GNUC__) && !defined(__EXCEPTIONS) +# define FMT_EXCEPTIONS 0 +#endif +#if FMT_MSC_VER && !_HAS_EXCEPTIONS +# define FMT_EXCEPTIONS 0 +#endif +#ifndef FMT_EXCEPTIONS +# define FMT_EXCEPTIONS 1 +#endif + +#ifndef FMT_THROW +# if FMT_EXCEPTIONS +# define FMT_THROW(x) throw x +# else +# define FMT_THROW(x) assert(false) +# endif +#endif + +// Define FMT_USE_NOEXCEPT to make fmt use noexcept (C++11 feature). +#ifndef FMT_USE_NOEXCEPT +# define FMT_USE_NOEXCEPT 0 +#endif + +#ifndef FMT_NOEXCEPT +# if FMT_EXCEPTIONS +# if FMT_USE_NOEXCEPT || FMT_HAS_FEATURE(cxx_noexcept) || \ + (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || \ + FMT_MSC_VER >= 1900 +# define FMT_NOEXCEPT noexcept +# else +# define FMT_NOEXCEPT throw() +# endif +# else +# define FMT_NOEXCEPT +# endif +#endif + +#ifndef FMT_OVERRIDE +# if FMT_USE_OVERRIDE || FMT_HAS_FEATURE(cxx_override) || \ + (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || \ + FMT_MSC_VER >= 1900 +# define FMT_OVERRIDE override +# else +# define FMT_OVERRIDE +# endif +#endif + + +// A macro to disallow the copy constructor and operator= functions +// This should be used in the private: declarations for a class +#ifndef FMT_USE_DELETED_FUNCTIONS +# define FMT_USE_DELETED_FUNCTIONS 0 +#endif + +#if FMT_USE_DELETED_FUNCTIONS || FMT_HAS_FEATURE(cxx_deleted_functions) || \ + (FMT_GCC_VERSION >= 404 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1800 +# define FMT_DELETED_OR_UNDEFINED = delete +# define FMT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + TypeName& operator=(const TypeName&) = delete +#else +# define FMT_DELETED_OR_UNDEFINED +# define FMT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&); \ + TypeName& operator=(const TypeName&) +#endif + +#ifndef FMT_USE_USER_DEFINED_LITERALS +// All compilers which support UDLs also support variadic templates. This +// makes the fmt::literals implementation easier. However, an explicit check +// for variadic templates is added here just in case. +// For Intel's compiler both it and the system gcc/msc must support UDLs. +# define FMT_USE_USER_DEFINED_LITERALS \ + FMT_USE_VARIADIC_TEMPLATES && FMT_USE_RVALUE_REFERENCES && \ + (FMT_HAS_FEATURE(cxx_user_literals) || \ + (FMT_GCC_VERSION >= 407 && FMT_HAS_GXX_CXX11) || FMT_MSC_VER >= 1900) && \ + (!defined(FMT_ICC_VERSION) || FMT_ICC_VERSION >= 1500) +#endif + +#ifndef FMT_ASSERT +# define FMT_ASSERT(condition, message) assert((condition) && message) +#endif + +#if FMT_GCC_VERSION >= 400 || FMT_HAS_BUILTIN(__builtin_clz) +# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) +#endif + +#if FMT_GCC_VERSION >= 400 || FMT_HAS_BUILTIN(__builtin_clzll) +# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) +#endif + +// Some compilers masquerade as both MSVC and GCC-likes or +// otherwise support __builtin_clz and __builtin_clzll, so +// only define FMT_BUILTIN_CLZ using the MSVC intrinsics +// if the clz and clzll builtins are not available. +#if FMT_MSC_VER && !defined(FMT_BUILTIN_CLZLL) +# include // _BitScanReverse, _BitScanReverse64 + +namespace fmt +{ +namespace internal +{ +# pragma intrinsic(_BitScanReverse) +inline uint32_t clz(uint32_t x) +{ + unsigned long r = 0; + _BitScanReverse(&r, x); + + assert(x != 0); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. +# pragma warning(suppress: 6102) + return 31 - r; +} +# define FMT_BUILTIN_CLZ(n) fmt::internal::clz(n) + +# ifdef _WIN64 +# pragma intrinsic(_BitScanReverse64) +# endif + +inline uint32_t clzll(uint64_t x) +{ + unsigned long r = 0; +# ifdef _WIN64 + _BitScanReverse64(&r, x); +# else + // Scan the high 32 bits. + if (_BitScanReverse(&r, static_cast(x >> 32))) + return 63 - (r + 32); + + // Scan the low 32 bits. + _BitScanReverse(&r, static_cast(x)); +# endif + + assert(x != 0); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. +# pragma warning(suppress: 6102) + return 63 - r; +} +# define FMT_BUILTIN_CLZLL(n) fmt::internal::clzll(n) +} +} +#endif + +namespace fmt +{ +namespace internal +{ +struct DummyInt +{ + int data[2]; + operator int() const + { + return 0; + } +}; +typedef std::numeric_limits FPUtil; + +// Dummy implementations of system functions such as signbit and ecvt called +// if the latter are not available. +inline DummyInt signbit(...) +{ + return DummyInt(); +} +inline DummyInt _ecvt_s(...) +{ + return DummyInt(); +} +inline DummyInt isinf(...) +{ + return DummyInt(); +} +inline DummyInt _finite(...) +{ + return DummyInt(); +} +inline DummyInt isnan(...) +{ + return DummyInt(); +} +inline DummyInt _isnan(...) +{ + return DummyInt(); +} + +// A helper function to suppress bogus "conditional expression is constant" +// warnings. +template +inline T const_check(T value) +{ + return value; +} +} +} // namespace fmt + +namespace std +{ +// Standard permits specialization of std::numeric_limits. This specialization +// is used to resolve ambiguity between isinf and std::isinf in glibc: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=48891 +// and the same for isnan and signbit. +template <> +class numeric_limits : + public std::numeric_limits +{ +public: + // Portable version of isinf. + template + static bool isinfinity(T x) + { + using namespace fmt::internal; + // The resolution "priority" is: + // isinf macro > std::isinf > ::isinf > fmt::internal::isinf + if (const_check(sizeof(isinf(x)) == sizeof(bool) || + sizeof(isinf(x)) == sizeof(int))) + { + return isinf(x) != 0; + } + return !_finite(static_cast(x)); + } + + // Portable version of isnan. + template + static bool isnotanumber(T x) + { + using namespace fmt::internal; + if (const_check(sizeof(isnan(x)) == sizeof(bool) || + sizeof(isnan(x)) == sizeof(int))) + { + return isnan(x) != 0; + } + return _isnan(static_cast(x)) != 0; + } + + // Portable version of signbit. + static bool isnegative(double x) + { + using namespace fmt::internal; + if (const_check(sizeof(signbit(x)) == sizeof(int))) + return signbit(x) != 0; + if (x < 0) return true; + if (!isnotanumber(x)) return false; + int dec = 0, sign = 0; + char buffer[2]; // The buffer size must be >= 2 or _ecvt_s will fail. + _ecvt_s(buffer, sizeof(buffer), x, 0, &dec, &sign); + return sign != 0; + } +}; +} // namespace std + +namespace fmt +{ + +// Fix the warning about long long on older versions of GCC +// that don't support the diagnostic pragma. +FMT_GCC_EXTENSION typedef long long LongLong; +FMT_GCC_EXTENSION typedef unsigned long long ULongLong; + +#if FMT_USE_RVALUE_REFERENCES +using std::move; +#endif + +template +class BasicWriter; + +typedef BasicWriter Writer; +typedef BasicWriter WWriter; + +template +class ArgFormatter; + +template > +class BasicFormatter; + +/** + \rst + A string reference. It can be constructed from a C string or ``std::string``. + + You can use one of the following typedefs for common character types: + + +------------+-------------------------+ + | Type | Definition | + +============+=========================+ + | StringRef | BasicStringRef | + +------------+-------------------------+ + | WStringRef | BasicStringRef | + +------------+-------------------------+ + + This class is most useful as a parameter type to allow passing + different types of strings to a function, for example:: + + template + std::string format(StringRef format_str, const Args & ... args); + + format("{}", 42); + format(std::string("{}"), 42); + \endrst + */ +template +class BasicStringRef +{ +private: + const Char *data_; + std::size_t size_; + +public: + /** Constructs a string reference object from a C string and a size. */ + BasicStringRef(const Char *s, std::size_t size) : data_(s), size_(size) {} + + /** + \rst + Constructs a string reference object from a C string computing + the size with ``std::char_traits::length``. + \endrst + */ + BasicStringRef(const Char *s) + : data_(s), size_(std::char_traits::length(s)) {} + + /** + \rst + Constructs a string reference from an ``std::string`` object. + \endrst + */ + BasicStringRef(const std::basic_string &s) + : data_(s.c_str()), size_(s.size()) {} + + /** + \rst + Converts a string reference to an ``std::string`` object. + \endrst + */ + std::basic_string to_string() const + { + return std::basic_string(data_, size_); + } + + /** Returns a pointer to the string data. */ + const Char *data() const + { + return data_; + } + + /** Returns the string size. */ + std::size_t size() const + { + return size_; + } + + // Lexicographically compare this string reference to other. + int compare(BasicStringRef other) const + { + std::size_t size = size_ < other.size_ ? size_ : other.size_; + int result = std::char_traits::compare(data_, other.data_, size); + if (result == 0) + result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); + return result; + } + + friend bool operator==(BasicStringRef lhs, BasicStringRef rhs) + { + return lhs.compare(rhs) == 0; + } + friend bool operator!=(BasicStringRef lhs, BasicStringRef rhs) + { + return lhs.compare(rhs) != 0; + } + friend bool operator<(BasicStringRef lhs, BasicStringRef rhs) + { + return lhs.compare(rhs) < 0; + } + friend bool operator<=(BasicStringRef lhs, BasicStringRef rhs) + { + return lhs.compare(rhs) <= 0; + } + friend bool operator>(BasicStringRef lhs, BasicStringRef rhs) + { + return lhs.compare(rhs) > 0; + } + friend bool operator>=(BasicStringRef lhs, BasicStringRef rhs) + { + return lhs.compare(rhs) >= 0; + } +}; + +typedef BasicStringRef StringRef; +typedef BasicStringRef WStringRef; + +/** + \rst + A reference to a null terminated string. It can be constructed from a C + string or ``std::string``. + + You can use one of the following typedefs for common character types: + + +-------------+--------------------------+ + | Type | Definition | + +=============+==========================+ + | CStringRef | BasicCStringRef | + +-------------+--------------------------+ + | WCStringRef | BasicCStringRef | + +-------------+--------------------------+ + + This class is most useful as a parameter type to allow passing + different types of strings to a function, for example:: + + template + std::string format(CStringRef format_str, const Args & ... args); + + format("{}", 42); + format(std::string("{}"), 42); + \endrst + */ +template +class BasicCStringRef +{ +private: + const Char *data_; + +public: + /** Constructs a string reference object from a C string. */ + BasicCStringRef(const Char *s) : data_(s) {} + + /** + \rst + Constructs a string reference from an ``std::string`` object. + \endrst + */ + BasicCStringRef(const std::basic_string &s) : data_(s.c_str()) {} + + /** Returns the pointer to a C string. */ + const Char *c_str() const + { + return data_; + } +}; + +typedef BasicCStringRef CStringRef; +typedef BasicCStringRef WCStringRef; + +/** A formatting error such as invalid format string. */ +class FormatError : public std::runtime_error +{ +public: + explicit FormatError(CStringRef message) + : std::runtime_error(message.c_str()) {} + ~FormatError() throw(); +}; + +namespace internal +{ + +// MakeUnsigned::Type gives an unsigned type corresponding to integer type T. +template +struct MakeUnsigned +{ + typedef T Type; +}; + +#define FMT_SPECIALIZE_MAKE_UNSIGNED(T, U) \ + template <> \ + struct MakeUnsigned { typedef U Type; } + +FMT_SPECIALIZE_MAKE_UNSIGNED(char, unsigned char); +FMT_SPECIALIZE_MAKE_UNSIGNED(signed char, unsigned char); +FMT_SPECIALIZE_MAKE_UNSIGNED(short, unsigned short); +FMT_SPECIALIZE_MAKE_UNSIGNED(int, unsigned); +FMT_SPECIALIZE_MAKE_UNSIGNED(long, unsigned long); +FMT_SPECIALIZE_MAKE_UNSIGNED(LongLong, ULongLong); + +// Casts nonnegative integer to unsigned. +template +inline typename MakeUnsigned::Type to_unsigned(Int value) +{ + FMT_ASSERT(value >= 0, "negative value"); + return static_cast::Type>(value); +} + +// The number of characters to store in the MemoryBuffer object itself +// to avoid dynamic memory allocation. +enum { INLINE_BUFFER_SIZE = 500 }; + +#if FMT_SECURE_SCL +// Use checked iterator to avoid warnings on MSVC. +template +inline stdext::checked_array_iterator make_ptr(T *ptr, std::size_t size) +{ + return stdext::checked_array_iterator(ptr, size); +} +#else +template +inline T *make_ptr(T *ptr, std::size_t) +{ + return ptr; +} +#endif +} // namespace internal + +/** + \rst + A buffer supporting a subset of ``std::vector``'s operations. + \endrst + */ +template +class Buffer +{ +private: + FMT_DISALLOW_COPY_AND_ASSIGN(Buffer); + +protected: + T *ptr_; + std::size_t size_; + std::size_t capacity_; + + Buffer(T *ptr = 0, std::size_t capacity = 0) + : ptr_(ptr), size_(0), capacity_(capacity) {} + + /** + \rst + Increases the buffer capacity to hold at least *size* elements updating + ``ptr_`` and ``capacity_``. + \endrst + */ + virtual void grow(std::size_t size) = 0; + +public: + virtual ~Buffer() {} + + /** Returns the size of this buffer. */ + std::size_t size() const + { + return size_; + } + + /** Returns the capacity of this buffer. */ + std::size_t capacity() const + { + return capacity_; + } + + /** + Resizes the buffer. If T is a POD type new elements may not be initialized. + */ + void resize(std::size_t new_size) + { + if (new_size > capacity_) + grow(new_size); + size_ = new_size; + } + + /** + \rst + Reserves space to store at least *capacity* elements. + \endrst + */ + void reserve(std::size_t capacity) + { + if (capacity > capacity_) + grow(capacity); + } + + void clear() FMT_NOEXCEPT { size_ = 0; } + + void push_back(const T &value) + { + if (size_ == capacity_) + grow(size_ + 1); + ptr_[size_++] = value; + } + + /** Appends data to the end of the buffer. */ + template + void append(const U *begin, const U *end); + + T &operator[](std::size_t index) + { + return ptr_[index]; + } + const T &operator[](std::size_t index) const + { + return ptr_[index]; + } +}; + +template +template +void Buffer::append(const U *begin, const U *end) +{ + std::size_t new_size = size_ + internal::to_unsigned(end - begin); + if (new_size > capacity_) + grow(new_size); + std::uninitialized_copy(begin, end, + internal::make_ptr(ptr_, capacity_) + size_); + size_ = new_size; +} + +namespace internal +{ + +// A memory buffer for trivially copyable/constructible types with the first +// SIZE elements stored in the object itself. +template > +class MemoryBuffer : private Allocator, public Buffer +{ +private: + T data_[SIZE]; + + // Deallocate memory allocated by the buffer. + void deallocate() + { + if (this->ptr_ != data_) Allocator::deallocate(this->ptr_, this->capacity_); + } + +protected: + void grow(std::size_t size) FMT_OVERRIDE; + +public: + explicit MemoryBuffer(const Allocator &alloc = Allocator()) + : Allocator(alloc), Buffer(data_, SIZE) {} + ~MemoryBuffer() + { + deallocate(); + } + +#if FMT_USE_RVALUE_REFERENCES +private: + // Move data from other to this buffer. + void move(MemoryBuffer &other) + { + Allocator &this_alloc = *this, &other_alloc = other; + this_alloc = std::move(other_alloc); + this->size_ = other.size_; + this->capacity_ = other.capacity_; + if (other.ptr_ == other.data_) + { + this->ptr_ = data_; + std::uninitialized_copy(other.data_, other.data_ + this->size_, + make_ptr(data_, this->capacity_)); + } + else + { + this->ptr_ = other.ptr_; + // Set pointer to the inline array so that delete is not called + // when deallocating. + other.ptr_ = other.data_; + } + } + +public: + MemoryBuffer(MemoryBuffer &&other) + { + move(other); + } + + MemoryBuffer &operator=(MemoryBuffer &&other) + { + assert(this != &other); + deallocate(); + move(other); + return *this; + } +#endif + + // Returns a copy of the allocator associated with this buffer. + Allocator get_allocator() const + { + return *this; + } +}; + +template +void MemoryBuffer::grow(std::size_t size) +{ + std::size_t new_capacity = this->capacity_ + this->capacity_ / 2; + if (size > new_capacity) + new_capacity = size; + T *new_ptr = this->allocate(new_capacity); + // The following code doesn't throw, so the raw pointer above doesn't leak. + std::uninitialized_copy(this->ptr_, this->ptr_ + this->size_, + make_ptr(new_ptr, new_capacity)); + std::size_t old_capacity = this->capacity_; + T *old_ptr = this->ptr_; + this->capacity_ = new_capacity; + this->ptr_ = new_ptr; + // deallocate may throw (at least in principle), but it doesn't matter since + // the buffer already uses the new storage and will deallocate it in case + // of exception. + if (old_ptr != data_) + Allocator::deallocate(old_ptr, old_capacity); +} + +// A fixed-size buffer. +template +class FixedBuffer : public fmt::Buffer +{ +public: + FixedBuffer(Char *array, std::size_t size) : fmt::Buffer(array, size) {} + +protected: + FMT_API void grow(std::size_t size); +}; + +template +class BasicCharTraits +{ +public: +#if FMT_SECURE_SCL + typedef stdext::checked_array_iterator CharPtr; +#else + typedef Char *CharPtr; +#endif + static Char cast(int value) + { + return static_cast(value); + } +}; + +template +class CharTraits; + +template <> +class CharTraits : public BasicCharTraits +{ +private: + // Conversion from wchar_t to char is not allowed. + static char convert(wchar_t); + +public: + static char convert(char value) + { + return value; + } + + // Formats a floating-point number. + template + FMT_API static int format_float(char *buffer, std::size_t size, + const char *format, unsigned width, int precision, T value); +}; + +template <> +class CharTraits : public BasicCharTraits +{ +public: + static wchar_t convert(char value) + { + return value; + } + static wchar_t convert(wchar_t value) + { + return value; + } + + template + FMT_API static int format_float(wchar_t *buffer, std::size_t size, + const wchar_t *format, unsigned width, int precision, T value); +}; + +// Checks if a number is negative - used to avoid warnings. +template +struct SignChecker +{ + template + static bool is_negative(T value) + { + return value < 0; + } +}; + +template <> +struct SignChecker +{ + template + static bool is_negative(T) + { + return false; + } +}; + +// Returns true if value is negative, false otherwise. +// Same as (value < 0) but doesn't produce warnings if T is an unsigned type. +template +inline bool is_negative(T value) +{ + return SignChecker::is_signed>::is_negative(value); +} + +// Selects uint32_t if FitsIn32Bits is true, uint64_t otherwise. +template +struct TypeSelector +{ + typedef uint32_t Type; +}; + +template <> +struct TypeSelector +{ + typedef uint64_t Type; +}; + +template +struct IntTraits +{ + // Smallest of uint32_t and uint64_t that is large enough to represent + // all values of T. + typedef typename + TypeSelector::digits <= 32>::Type MainType; +}; + +FMT_API void report_unknown_type(char code, const char *type); + +// Static data is placed in this class template to allow header-only +// configuration. +template +struct FMT_API BasicData +{ + static const uint32_t POWERS_OF_10_32[]; + static const uint64_t POWERS_OF_10_64[]; + static const char DIGITS[]; +}; + +#ifndef FMT_USE_EXTERN_TEMPLATES +// Clang doesn't have a feature check for extern templates so we check +// for variadic templates which were introduced in the same version. +# define FMT_USE_EXTERN_TEMPLATES (__clang__ && FMT_USE_VARIADIC_TEMPLATES) +#endif + +#if FMT_USE_EXTERN_TEMPLATES && !defined(FMT_HEADER_ONLY) +extern template struct BasicData; +#endif + +typedef BasicData<> Data; + +#ifdef FMT_BUILTIN_CLZLL +// Returns the number of decimal digits in n. Leading zeros are not counted +// except for n == 0 in which case count_digits returns 1. +inline unsigned count_digits(uint64_t n) +{ + // Based on http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10 + // and the benchmark https://github.com/localvoid/cxx-benchmark-count-digits. + int t = (64 - FMT_BUILTIN_CLZLL(n | 1)) * 1233 >> 12; + return to_unsigned(t) - (n < Data::POWERS_OF_10_64[t]) + 1; +} +#else +// Fallback version of count_digits used when __builtin_clz is not available. +inline unsigned count_digits(uint64_t n) +{ + unsigned count = 1; + for (;;) + { + // Integer division is slow so do it for a group of four digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + if (n < 10) return count; + if (n < 100) return count + 1; + if (n < 1000) return count + 2; + if (n < 10000) return count + 3; + n /= 10000u; + count += 4; + } +} +#endif + +#ifdef FMT_BUILTIN_CLZ +// Optional version of count_digits for better performance on 32-bit platforms. +inline unsigned count_digits(uint32_t n) +{ + int t = (32 - FMT_BUILTIN_CLZ(n | 1)) * 1233 >> 12; + return to_unsigned(t) - (n < Data::POWERS_OF_10_32[t]) + 1; +} +#endif + +// A functor that doesn't add a thousands separator. +struct NoThousandsSep +{ + template + void operator()(Char *) {} +}; + +// A functor that adds a thousands separator. +class ThousandsSep +{ +private: + fmt::StringRef sep_; + + // Index of a decimal digit with the least significant digit having index 0. + unsigned digit_index_; + +public: + explicit ThousandsSep(fmt::StringRef sep) : sep_(sep), digit_index_(0) {} + + template + void operator()(Char *&buffer) + { + if (++digit_index_ % 3 != 0) + return; + buffer -= sep_.size(); + std::uninitialized_copy(sep_.data(), sep_.data() + sep_.size(), + internal::make_ptr(buffer, sep_.size())); + } +}; + +// Formats a decimal unsigned integer value writing into buffer. +// thousands_sep is a functor that is called after writing each char to +// add a thousands separator if necessary. +template +inline void format_decimal(Char *buffer, UInt value, unsigned num_digits, + ThousandsSep thousands_sep) +{ + buffer += num_digits; + while (value >= 100) + { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + unsigned index = static_cast((value % 100) * 2); + value /= 100; + *--buffer = Data::DIGITS[index + 1]; + thousands_sep(buffer); + *--buffer = Data::DIGITS[index]; + thousands_sep(buffer); + } + if (value < 10) + { + *--buffer = static_cast('0' + value); + return; + } + unsigned index = static_cast(value * 2); + *--buffer = Data::DIGITS[index + 1]; + thousands_sep(buffer); + *--buffer = Data::DIGITS[index]; +} + +template +inline void format_decimal(Char *buffer, UInt value, unsigned num_digits) +{ + return format_decimal(buffer, value, num_digits, NoThousandsSep()); +} + +#ifndef _WIN32 +# define FMT_USE_WINDOWS_H 0 +#elif !defined(FMT_USE_WINDOWS_H) +# define FMT_USE_WINDOWS_H 1 +#endif + +// Define FMT_USE_WINDOWS_H to 0 to disable use of windows.h. +// All the functionality that relies on it will be disabled too. +#if FMT_USE_WINDOWS_H +// A converter from UTF-8 to UTF-16. +// It is only provided for Windows since other systems support UTF-8 natively. +class UTF8ToUTF16 +{ +private: + MemoryBuffer buffer_; + +public: + FMT_API explicit UTF8ToUTF16(StringRef s); + operator WStringRef() const + { + return WStringRef(&buffer_[0], size()); + } + size_t size() const + { + return buffer_.size() - 1; + } + const wchar_t *c_str() const + { + return &buffer_[0]; + } + std::wstring str() const + { + return std::wstring(&buffer_[0], size()); + } +}; + +// A converter from UTF-16 to UTF-8. +// It is only provided for Windows since other systems support UTF-8 natively. +class UTF16ToUTF8 +{ +private: + MemoryBuffer buffer_; + +public: + UTF16ToUTF8() {} + FMT_API explicit UTF16ToUTF8(WStringRef s); + operator StringRef() const + { + return StringRef(&buffer_[0], size()); + } + size_t size() const + { + return buffer_.size() - 1; + } + const char *c_str() const + { + return &buffer_[0]; + } + std::string str() const + { + return std::string(&buffer_[0], size()); + } + + // Performs conversion returning a system error code instead of + // throwing exception on conversion error. This method may still throw + // in case of memory allocation error. + FMT_API int convert(WStringRef s); +}; + +FMT_API void format_windows_error(fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT; +#endif + +FMT_API void format_system_error(fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT; + +// A formatting argument value. +struct Value +{ + template + struct StringValue + { + const Char *value; + std::size_t size; + }; + + typedef void (*FormatFunc)( + void *formatter, const void *arg, void *format_str_ptr); + + struct CustomValue + { + const void *value; + FormatFunc format; + }; + + union + { + int int_value; + unsigned uint_value; + LongLong long_long_value; + ULongLong ulong_long_value; + double double_value; + long double long_double_value; + const void *pointer; + StringValue string; + StringValue sstring; + StringValue ustring; + StringValue wstring; + CustomValue custom; + }; + + enum Type + { + NONE, NAMED_ARG, + // Integer types should go first, + INT, UINT, LONG_LONG, ULONG_LONG, BOOL, CHAR, LAST_INTEGER_TYPE = CHAR, + // followed by floating-point types. + DOUBLE, LONG_DOUBLE, LAST_NUMERIC_TYPE = LONG_DOUBLE, + CSTRING, STRING, WSTRING, POINTER, CUSTOM + }; +}; + +// A formatting argument. It is a trivially copyable/constructible type to +// allow storage in internal::MemoryBuffer. +struct Arg : Value +{ + Type type; +}; + +template +struct NamedArg; + +template +struct Null {}; + +// A helper class template to enable or disable overloads taking wide +// characters and strings in MakeValue. +template +struct WCharHelper +{ + typedef Null Supported; + typedef T Unsupported; +}; + +template +struct WCharHelper +{ + typedef T Supported; + typedef Null Unsupported; +}; + +typedef char Yes[1]; +typedef char No[2]; + +template +T &get(); + +// These are non-members to workaround an overload resolution bug in bcc32. +Yes &convert(fmt::ULongLong); +No &convert(...); + +template +struct ConvertToIntImpl +{ + enum { value = ENABLE_CONVERSION }; +}; + +template +struct ConvertToIntImpl2 +{ + enum { value = false }; +}; + +template +struct ConvertToIntImpl2 +{ + enum + { + // Don't convert numeric types. + value = ConvertToIntImpl::is_specialized>::value + }; +}; + +template +struct ConvertToInt +{ + enum { enable_conversion = sizeof(convert(get())) == sizeof(Yes) }; + enum { value = ConvertToIntImpl2::value }; +}; + +#define FMT_DISABLE_CONVERSION_TO_INT(Type) \ + template <> \ + struct ConvertToInt { enum { value = 0 }; } + +// Silence warnings about convering float to int. +FMT_DISABLE_CONVERSION_TO_INT(float); +FMT_DISABLE_CONVERSION_TO_INT(double); +FMT_DISABLE_CONVERSION_TO_INT(long double); + +template +struct EnableIf {}; + +template +struct EnableIf +{ + typedef T type; +}; + +template +struct Conditional +{ + typedef T type; +}; + +template +struct Conditional +{ + typedef F type; +}; + +// For bcc32 which doesn't understand ! in template arguments. +template +struct Not +{ + enum { value = 0 }; +}; + +template<> +struct Not +{ + enum { value = 1 }; +}; + +template struct LConvCheck +{ + LConvCheck(int) {} +}; + +// Returns the thousands separator for the current locale. +// We check if ``lconv`` contains ``thousands_sep`` because on Android +// ``lconv`` is stubbed as an empty struct. +template +inline StringRef thousands_sep( + LConv *lc, LConvCheck = 0) +{ + return lc->thousands_sep; +} + +inline fmt::StringRef thousands_sep(...) +{ + return ""; +} + +// Makes an Arg object from any type. +template +class MakeValue : public Arg +{ +public: + typedef typename Formatter::Char Char; + +private: + // The following two methods are private to disallow formatting of + // arbitrary pointers. If you want to output a pointer cast it to + // "void *" or "const void *". In particular, this forbids formatting + // of "[const] volatile char *" which is printed as bool by iostreams. + // Do not implement! + template + MakeValue(const T *value); + template + MakeValue(T *value); + + // The following methods are private to disallow formatting of wide + // characters and strings into narrow strings as in + // fmt::format("{}", L"test"); + // To fix this, use a wide format string: fmt::format(L"{}", L"test"). +#if !FMT_MSC_VER || defined(_NATIVE_WCHAR_T_DEFINED) + MakeValue(typename WCharHelper::Unsupported); +#endif + MakeValue(typename WCharHelper::Unsupported); + MakeValue(typename WCharHelper::Unsupported); + MakeValue(typename WCharHelper::Unsupported); + MakeValue(typename WCharHelper::Unsupported); + + void set_string(StringRef str) + { + string.value = str.data(); + string.size = str.size(); + } + + void set_string(WStringRef str) + { + wstring.value = str.data(); + wstring.size = str.size(); + } + + // Formats an argument of a custom type, such as a user-defined class. + template + static void format_custom_arg( + void *formatter, const void *arg, void *format_str_ptr) + { + format(*static_cast(formatter), + *static_cast(format_str_ptr), + *static_cast(arg)); + } + +public: + MakeValue() {} + +#define FMT_MAKE_VALUE_(Type, field, TYPE, rhs) \ + MakeValue(Type value) { field = rhs; } \ + static uint64_t type(Type) { return Arg::TYPE; } + +#define FMT_MAKE_VALUE(Type, field, TYPE) \ + FMT_MAKE_VALUE_(Type, field, TYPE, value) + + FMT_MAKE_VALUE(bool, int_value, BOOL) + FMT_MAKE_VALUE(short, int_value, INT) + FMT_MAKE_VALUE(unsigned short, uint_value, UINT) + FMT_MAKE_VALUE(int, int_value, INT) + FMT_MAKE_VALUE(unsigned, uint_value, UINT) + + MakeValue(long value) + { + // To minimize the number of types we need to deal with, long is + // translated either to int or to long long depending on its size. + if (const_check(sizeof(long) == sizeof(int))) + int_value = static_cast(value); + else + long_long_value = value; + } + static uint64_t type(long) + { + return sizeof(long) == sizeof(int) ? Arg::INT : Arg::LONG_LONG; + } + + MakeValue(unsigned long value) + { + if (const_check(sizeof(unsigned long) == sizeof(unsigned))) + uint_value = static_cast(value); + else + ulong_long_value = value; + } + static uint64_t type(unsigned long) + { + return sizeof(unsigned long) == sizeof(unsigned) ? + Arg::UINT : Arg::ULONG_LONG; + } + + FMT_MAKE_VALUE(LongLong, long_long_value, LONG_LONG) + FMT_MAKE_VALUE(ULongLong, ulong_long_value, ULONG_LONG) + FMT_MAKE_VALUE(float, double_value, DOUBLE) + FMT_MAKE_VALUE(double, double_value, DOUBLE) + FMT_MAKE_VALUE(long double, long_double_value, LONG_DOUBLE) + FMT_MAKE_VALUE(signed char, int_value, INT) + FMT_MAKE_VALUE(unsigned char, uint_value, UINT) + FMT_MAKE_VALUE(char, int_value, CHAR) + +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) + MakeValue(typename WCharHelper::Supported value) + { + int_value = value; + } + static uint64_t type(wchar_t) + { + return Arg::CHAR; + } +#endif + +#define FMT_MAKE_STR_VALUE(Type, TYPE) \ + MakeValue(Type value) { set_string(value); } \ + static uint64_t type(Type) { return Arg::TYPE; } + + FMT_MAKE_VALUE(char *, string.value, CSTRING) + FMT_MAKE_VALUE(const char *, string.value, CSTRING) + FMT_MAKE_VALUE(signed char *, sstring.value, CSTRING) + FMT_MAKE_VALUE(const signed char *, sstring.value, CSTRING) + FMT_MAKE_VALUE(unsigned char *, ustring.value, CSTRING) + FMT_MAKE_VALUE(const unsigned char *, ustring.value, CSTRING) + FMT_MAKE_STR_VALUE(const std::string &, STRING) + FMT_MAKE_STR_VALUE(StringRef, STRING) + FMT_MAKE_VALUE_(CStringRef, string.value, CSTRING, value.c_str()) + +#define FMT_MAKE_WSTR_VALUE(Type, TYPE) \ + MakeValue(typename WCharHelper::Supported value) { \ + set_string(value); \ + } \ + static uint64_t type(Type) { return Arg::TYPE; } + + FMT_MAKE_WSTR_VALUE(wchar_t *, WSTRING) + FMT_MAKE_WSTR_VALUE(const wchar_t *, WSTRING) + FMT_MAKE_WSTR_VALUE(const std::wstring &, WSTRING) + FMT_MAKE_WSTR_VALUE(WStringRef, WSTRING) + + FMT_MAKE_VALUE(void *, pointer, POINTER) + FMT_MAKE_VALUE(const void *, pointer, POINTER) + + template + MakeValue(const T &value, + typename EnableIf::value>::value, int>::type = 0) + { + custom.value = &value; + custom.format = &format_custom_arg; + } + + template + MakeValue(const T &value, + typename EnableIf::value, int>::type = 0) + { + int_value = value; + } + + template + static uint64_t type(const T &) + { + return ConvertToInt::value ? Arg::INT : Arg::CUSTOM; + } + + // Additional template param `Char_` is needed here because make_type always + // uses char. + template + MakeValue(const NamedArg &value) + { + pointer = &value; + } + + template + static uint64_t type(const NamedArg &) + { + return Arg::NAMED_ARG; + } +}; + +template +class MakeArg : public Arg +{ +public: + MakeArg() + { + type = Arg::NONE; + } + + template + MakeArg(const T &value) + : Arg(MakeValue(value)) + { + type = static_cast(MakeValue::type(value)); + } +}; + +template +struct NamedArg : Arg +{ + BasicStringRef name; + + template + NamedArg(BasicStringRef argname, const T &value) + : Arg(MakeArg< BasicFormatter >(value)), name(argname) {} +}; + +class RuntimeError : public std::runtime_error +{ +protected: + RuntimeError() : std::runtime_error("") {} + ~RuntimeError() throw(); +}; + +template +class PrintfArgFormatter; + +template +class ArgMap; +} // namespace internal + +/** An argument list. */ +class ArgList +{ +private: + // To reduce compiled code size per formatting function call, types of first + // MAX_PACKED_ARGS arguments are passed in the types_ field. + uint64_t types_; + union + { + // If the number of arguments is less than MAX_PACKED_ARGS, the argument + // values are stored in values_, otherwise they are stored in args_. + // This is done to reduce compiled code size as storing larger objects + // may require more code (at least on x86-64) even if the same amount of + // data is actually copied to stack. It saves ~10% on the bloat test. + const internal::Value *values_; + const internal::Arg *args_; + }; + + internal::Arg::Type type(unsigned index) const + { + unsigned shift = index * 4; + uint64_t mask = 0xf; + return static_cast( + (types_ & (mask << shift)) >> shift); + } + + template + friend class internal::ArgMap; + +public: + // Maximum number of arguments with packed types. + enum { MAX_PACKED_ARGS = 16 }; + + ArgList() : types_(0) {} + + ArgList(ULongLong types, const internal::Value *values) + : types_(types), values_(values) {} + ArgList(ULongLong types, const internal::Arg *args) + : types_(types), args_(args) {} + + /** Returns the argument at specified index. */ + internal::Arg operator[](unsigned index) const + { + using internal::Arg; + Arg arg; + bool use_values = type(MAX_PACKED_ARGS - 1) == Arg::NONE; + if (index < MAX_PACKED_ARGS) + { + Arg::Type arg_type = type(index); + internal::Value &val = arg; + if (arg_type != Arg::NONE) + val = use_values ? values_[index] : args_[index]; + arg.type = arg_type; + return arg; + } + if (use_values) + { + // The index is greater than the number of arguments that can be stored + // in values, so return a "none" argument. + arg.type = Arg::NONE; + return arg; + } + for (unsigned i = MAX_PACKED_ARGS; i <= index; ++i) + { + if (args_[i].type == Arg::NONE) + return args_[i]; + } + return args_[index]; + } +}; + +#define FMT_DISPATCH(call) static_cast(this)->call + +/** + \rst + An argument visitor based on the `curiously recurring template pattern + `_. + + To use `~fmt::ArgVisitor` define a subclass that implements some or all of the + visit methods with the same signatures as the methods in `~fmt::ArgVisitor`, + for example, `~fmt::ArgVisitor::visit_int()`. + Pass the subclass as the *Impl* template parameter. Then calling + `~fmt::ArgVisitor::visit` for some argument will dispatch to a visit method + specific to the argument type. For example, if the argument type is + ``double`` then the `~fmt::ArgVisitor::visit_double()` method of a subclass + will be called. If the subclass doesn't contain a method with this signature, + then a corresponding method of `~fmt::ArgVisitor` will be called. + + **Example**:: + + class MyArgVisitor : public fmt::ArgVisitor { + public: + void visit_int(int value) { fmt::print("{}", value); } + void visit_double(double value) { fmt::print("{}", value ); } + }; + \endrst + */ +template +class ArgVisitor +{ +private: + typedef internal::Arg Arg; + +public: + void report_unhandled_arg() {} + + Result visit_unhandled_arg() + { + FMT_DISPATCH(report_unhandled_arg()); + return Result(); + } + + /** Visits an ``int`` argument. **/ + Result visit_int(int value) + { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits a ``long long`` argument. **/ + Result visit_long_long(LongLong value) + { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits an ``unsigned`` argument. **/ + Result visit_uint(unsigned value) + { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits an ``unsigned long long`` argument. **/ + Result visit_ulong_long(ULongLong value) + { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits a ``bool`` argument. **/ + Result visit_bool(bool value) + { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits a ``char`` or ``wchar_t`` argument. **/ + Result visit_char(int value) + { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits an argument of any integral type. **/ + template + Result visit_any_int(T) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a ``double`` argument. **/ + Result visit_double(double value) + { + return FMT_DISPATCH(visit_any_double(value)); + } + + /** Visits a ``long double`` argument. **/ + Result visit_long_double(long double value) + { + return FMT_DISPATCH(visit_any_double(value)); + } + + /** Visits a ``double`` or ``long double`` argument. **/ + template + Result visit_any_double(T) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a null-terminated C string (``const char *``) argument. **/ + Result visit_cstring(const char *) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a string argument. **/ + Result visit_string(Arg::StringValue) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a wide string argument. **/ + Result visit_wstring(Arg::StringValue) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a pointer argument. **/ + Result visit_pointer(const void *) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits an argument of a custom (user-defined) type. **/ + Result visit_custom(Arg::CustomValue) + { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** + \rst + Visits an argument dispatching to the appropriate visit method based on + the argument type. For example, if the argument type is ``double`` then + the `~fmt::ArgVisitor::visit_double()` method of the *Impl* class will be + called. + \endrst + */ + Result visit(const Arg &arg) + { + switch (arg.type) + { + case Arg::NONE: + case Arg::NAMED_ARG: + FMT_ASSERT(false, "invalid argument type"); + break; + case Arg::INT: + return FMT_DISPATCH(visit_int(arg.int_value)); + case Arg::UINT: + return FMT_DISPATCH(visit_uint(arg.uint_value)); + case Arg::LONG_LONG: + return FMT_DISPATCH(visit_long_long(arg.long_long_value)); + case Arg::ULONG_LONG: + return FMT_DISPATCH(visit_ulong_long(arg.ulong_long_value)); + case Arg::BOOL: + return FMT_DISPATCH(visit_bool(arg.int_value != 0)); + case Arg::CHAR: + return FMT_DISPATCH(visit_char(arg.int_value)); + case Arg::DOUBLE: + return FMT_DISPATCH(visit_double(arg.double_value)); + case Arg::LONG_DOUBLE: + return FMT_DISPATCH(visit_long_double(arg.long_double_value)); + case Arg::CSTRING: + return FMT_DISPATCH(visit_cstring(arg.string.value)); + case Arg::STRING: + return FMT_DISPATCH(visit_string(arg.string)); + case Arg::WSTRING: + return FMT_DISPATCH(visit_wstring(arg.wstring)); + case Arg::POINTER: + return FMT_DISPATCH(visit_pointer(arg.pointer)); + case Arg::CUSTOM: + return FMT_DISPATCH(visit_custom(arg.custom)); + } + return Result(); + } +}; + +enum Alignment +{ + ALIGN_DEFAULT, ALIGN_LEFT, ALIGN_RIGHT, ALIGN_CENTER, ALIGN_NUMERIC +}; + +// Flags. +enum +{ + SIGN_FLAG = 1, PLUS_FLAG = 2, MINUS_FLAG = 4, HASH_FLAG = 8, + CHAR_FLAG = 0x10 // Argument has char type - used in error reporting. +}; + +// An empty format specifier. +struct EmptySpec {}; + +// A type specifier. +template +struct TypeSpec : EmptySpec +{ + Alignment align() const + { + return ALIGN_DEFAULT; + } + unsigned width() const + { + return 0; + } + int precision() const + { + return -1; + } + bool flag(unsigned) const + { + return false; + } + char type() const + { + return TYPE; + } + char fill() const + { + return ' '; + } +}; + +// A width specifier. +struct WidthSpec +{ + unsigned width_; + // Fill is always wchar_t and cast to char if necessary to avoid having + // two specialization of WidthSpec and its subclasses. + wchar_t fill_; + + WidthSpec(unsigned width, wchar_t fill) : width_(width), fill_(fill) {} + + unsigned width() const + { + return width_; + } + wchar_t fill() const + { + return fill_; + } +}; + +// An alignment specifier. +struct AlignSpec : WidthSpec +{ + Alignment align_; + + AlignSpec(unsigned width, wchar_t fill, Alignment align = ALIGN_DEFAULT) + : WidthSpec(width, fill), align_(align) {} + + Alignment align() const + { + return align_; + } + + int precision() const + { + return -1; + } +}; + +// An alignment and type specifier. +template +struct AlignTypeSpec : AlignSpec +{ + AlignTypeSpec(unsigned width, wchar_t fill) : AlignSpec(width, fill) {} + + bool flag(unsigned) const + { + return false; + } + char type() const + { + return TYPE; + } +}; + +// A full format specifier. +struct FormatSpec : AlignSpec +{ + unsigned flags_; + int precision_; + char type_; + + FormatSpec( + unsigned width = 0, char type = 0, wchar_t fill = ' ') + : AlignSpec(width, fill), flags_(0), precision_(-1), type_(type) {} + + bool flag(unsigned f) const + { + return (flags_ & f) != 0; + } + int precision() const + { + return precision_; + } + char type() const + { + return type_; + } +}; + +// An integer format specifier. +template , typename Char = char> +class IntFormatSpec : public SpecT +{ +private: + T value_; + +public: + IntFormatSpec(T val, const SpecT &spec = SpecT()) + : SpecT(spec), value_(val) {} + + T value() const + { + return value_; + } +}; + +// A string format specifier. +template +class StrFormatSpec : public AlignSpec +{ +private: + const Char *str_; + +public: + template + StrFormatSpec(const Char *str, unsigned width, FillChar fill) + : AlignSpec(width, fill), str_(str) + { + internal::CharTraits::convert(FillChar()); + } + + const Char *str() const + { + return str_; + } +}; + +/** + Returns an integer format specifier to format the value in base 2. + */ +IntFormatSpec > bin(int value); + +/** + Returns an integer format specifier to format the value in base 8. + */ +IntFormatSpec > oct(int value); + +/** + Returns an integer format specifier to format the value in base 16 using + lower-case letters for the digits above 9. + */ +IntFormatSpec > hex(int value); + +/** + Returns an integer formatter format specifier to format in base 16 using + upper-case letters for the digits above 9. + */ +IntFormatSpec > hexu(int value); + +/** + \rst + Returns an integer format specifier to pad the formatted argument with the + fill character to the specified width using the default (right) numeric + alignment. + + **Example**:: + + MemoryWriter out; + out << pad(hex(0xcafe), 8, '0'); + // out.str() == "0000cafe" + + \endrst + */ +template +IntFormatSpec, Char> pad( + int value, unsigned width, Char fill = ' '); + +#define FMT_DEFINE_INT_FORMATTERS(TYPE) \ +inline IntFormatSpec > bin(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'b'>()); \ +} \ + \ +inline IntFormatSpec > oct(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'o'>()); \ +} \ + \ +inline IntFormatSpec > hex(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'x'>()); \ +} \ + \ +inline IntFormatSpec > hexu(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'X'>()); \ +} \ + \ +template \ +inline IntFormatSpec > pad( \ + IntFormatSpec > f, unsigned width) { \ + return IntFormatSpec >( \ + f.value(), AlignTypeSpec(width, ' ')); \ +} \ + \ +/* For compatibility with older compilers we provide two overloads for pad, */ \ +/* one that takes a fill character and one that doesn't. In the future this */ \ +/* can be replaced with one overload making the template argument Char */ \ +/* default to char (C++11). */ \ +template \ +inline IntFormatSpec, Char> pad( \ + IntFormatSpec, Char> f, \ + unsigned width, Char fill) { \ + return IntFormatSpec, Char>( \ + f.value(), AlignTypeSpec(width, fill)); \ +} \ + \ +inline IntFormatSpec > pad( \ + TYPE value, unsigned width) { \ + return IntFormatSpec >( \ + value, AlignTypeSpec<0>(width, ' ')); \ +} \ + \ +template \ +inline IntFormatSpec, Char> pad( \ + TYPE value, unsigned width, Char fill) { \ + return IntFormatSpec, Char>( \ + value, AlignTypeSpec<0>(width, fill)); \ +} + +FMT_DEFINE_INT_FORMATTERS(int) +FMT_DEFINE_INT_FORMATTERS(long) +FMT_DEFINE_INT_FORMATTERS(unsigned) +FMT_DEFINE_INT_FORMATTERS(unsigned long) +FMT_DEFINE_INT_FORMATTERS(LongLong) +FMT_DEFINE_INT_FORMATTERS(ULongLong) + +/** + \rst + Returns a string formatter that pads the formatted argument with the fill + character to the specified width using the default (left) string alignment. + + **Example**:: + + std::string s = str(MemoryWriter() << pad("abc", 8)); + // s == "abc " + + \endrst + */ +template +inline StrFormatSpec pad( + const Char *str, unsigned width, Char fill = ' ') +{ + return StrFormatSpec(str, width, fill); +} + +inline StrFormatSpec pad( + const wchar_t *str, unsigned width, char fill = ' ') +{ + return StrFormatSpec(str, width, fill); +} + +namespace internal +{ + +template +class ArgMap +{ +private: + typedef std::vector< + std::pair, internal::Arg> > MapType; + typedef typename MapType::value_type Pair; + + MapType map_; + +public: + FMT_API void init(const ArgList &args); + + const internal::Arg* find(const fmt::BasicStringRef &name) const + { + // The list is unsorted, so just return the first matching name. + for (typename MapType::const_iterator it = map_.begin(), end = map_.end(); + it != end; ++it) + { + if (it->first == name) + return &it->second; + } + return 0; + } +}; + +template +class ArgFormatterBase : public ArgVisitor +{ +private: + BasicWriter &writer_; + FormatSpec &spec_; + + FMT_DISALLOW_COPY_AND_ASSIGN(ArgFormatterBase); + + void write_pointer(const void *p) + { + spec_.flags_ = HASH_FLAG; + spec_.type_ = 'x'; + writer_.write_int(reinterpret_cast(p), spec_); + } + +protected: + BasicWriter &writer() + { + return writer_; + } + FormatSpec &spec() + { + return spec_; + } + + void write(bool value) + { + const char *str_value = value ? "true" : "false"; + Arg::StringValue str = { str_value, std::strlen(str_value) }; + writer_.write_str(str, spec_); + } + + void write(const char *value) + { + Arg::StringValue str = {value, value != 0 ? std::strlen(value) : 0}; + writer_.write_str(str, spec_); + } + +public: + ArgFormatterBase(BasicWriter &w, FormatSpec &s) + : writer_(w), spec_(s) {} + + template + void visit_any_int(T value) + { + writer_.write_int(value, spec_); + } + + template + void visit_any_double(T value) + { + writer_.write_double(value, spec_); + } + + void visit_bool(bool value) + { + if (spec_.type_) + return visit_any_int(value); + write(value); + } + + void visit_char(int value) + { + if (spec_.type_ && spec_.type_ != 'c') + { + spec_.flags_ |= CHAR_FLAG; + writer_.write_int(value, spec_); + return; + } + if (spec_.align_ == ALIGN_NUMERIC || spec_.flags_ != 0) + FMT_THROW(FormatError("invalid format specifier for char")); + typedef typename BasicWriter::CharPtr CharPtr; + Char fill = internal::CharTraits::cast(spec_.fill()); + CharPtr out = CharPtr(); + const unsigned CHAR_SIZE = 1; + if (spec_.width_ > CHAR_SIZE) + { + out = writer_.grow_buffer(spec_.width_); + if (spec_.align_ == ALIGN_RIGHT) + { + std::uninitialized_fill_n(out, spec_.width_ - CHAR_SIZE, fill); + out += spec_.width_ - CHAR_SIZE; + } + else if (spec_.align_ == ALIGN_CENTER) + { + out = writer_.fill_padding(out, spec_.width_, + internal::const_check(CHAR_SIZE), fill); + } + else + { + std::uninitialized_fill_n(out + CHAR_SIZE, + spec_.width_ - CHAR_SIZE, fill); + } + } + else + { + out = writer_.grow_buffer(CHAR_SIZE); + } + *out = internal::CharTraits::cast(value); + } + + void visit_cstring(const char *value) + { + if (spec_.type_ == 'p') + return write_pointer(value); + write(value); + } + + void visit_string(Arg::StringValue value) + { + writer_.write_str(value, spec_); + } + + using ArgVisitor::visit_wstring; + + void visit_wstring(Arg::StringValue value) + { + writer_.write_str(value, spec_); + } + + void visit_pointer(const void *value) + { + if (spec_.type_ && spec_.type_ != 'p') + report_unknown_type(spec_.type_, "pointer"); + write_pointer(value); + } +}; + +class FormatterBase +{ +private: + ArgList args_; + int next_arg_index_; + + // Returns the argument with specified index. + FMT_API Arg do_get_arg(unsigned arg_index, const char *&error); + +protected: + const ArgList &args() const + { + return args_; + } + + explicit FormatterBase(const ArgList &args) + { + args_ = args; + next_arg_index_ = 0; + } + + // Returns the next argument. + Arg next_arg(const char *&error) + { + if (next_arg_index_ >= 0) + return do_get_arg(internal::to_unsigned(next_arg_index_++), error); + error = "cannot switch from manual to automatic argument indexing"; + return Arg(); + } + + // Checks if manual indexing is used and returns the argument with + // specified index. + Arg get_arg(unsigned arg_index, const char *&error) + { + return check_no_auto_index(error) ? do_get_arg(arg_index, error) : Arg(); + } + + bool check_no_auto_index(const char *&error) + { + if (next_arg_index_ > 0) + { + error = "cannot switch from automatic to manual argument indexing"; + return false; + } + next_arg_index_ = -1; + return true; + } + + template + void write(BasicWriter &w, const Char *start, const Char *end) + { + if (start != end) + w << BasicStringRef(start, internal::to_unsigned(end - start)); + } +}; + +// A printf formatter. +template +class PrintfFormatter : private FormatterBase +{ +private: + void parse_flags(FormatSpec &spec, const Char *&s); + + // Returns the argument with specified index or, if arg_index is equal + // to the maximum unsigned value, the next argument. + Arg get_arg(const Char *s, + unsigned arg_index = (std::numeric_limits::max)()); + + // Parses argument index, flags and width and returns the argument index. + unsigned parse_header(const Char *&s, FormatSpec &spec); + +public: + explicit PrintfFormatter(const ArgList &args) : FormatterBase(args) {} + FMT_API void format(BasicWriter &writer, + BasicCStringRef format_str); +}; +} // namespace internal + +/** + \rst + An argument formatter based on the `curiously recurring template pattern + `_. + + To use `~fmt::BasicArgFormatter` define a subclass that implements some or + all of the visit methods with the same signatures as the methods in + `~fmt::ArgVisitor`, for example, `~fmt::ArgVisitor::visit_int()`. + Pass the subclass as the *Impl* template parameter. When a formatting + function processes an argument, it will dispatch to a visit method + specific to the argument type. For example, if the argument type is + ``double`` then the `~fmt::ArgVisitor::visit_double()` method of a subclass + will be called. If the subclass doesn't contain a method with this signature, + then a corresponding method of `~fmt::BasicArgFormatter` or its superclass + will be called. + \endrst + */ +template +class BasicArgFormatter : public internal::ArgFormatterBase +{ +private: + BasicFormatter &formatter_; + const Char *format_; + +public: + /** + \rst + Constructs an argument formatter object. + *formatter* is a reference to the main formatter object, *spec* contains + format specifier information for standard argument types, and *fmt* points + to the part of the format string being parsed for custom argument types. + \endrst + */ + BasicArgFormatter(BasicFormatter &formatter, + FormatSpec &spec, const Char *fmt) + : internal::ArgFormatterBase(formatter.writer(), spec), + formatter_(formatter), format_(fmt) {} + + /** Formats argument of a custom (user-defined) type. */ + void visit_custom(internal::Arg::CustomValue c) + { + c.format(&formatter_, c.value, &format_); + } +}; + +/** The default argument formatter. */ +template +class ArgFormatter : public BasicArgFormatter, Char> +{ +public: + /** Constructs an argument formatter object. */ + ArgFormatter(BasicFormatter &formatter, + FormatSpec &spec, const Char *fmt) + : BasicArgFormatter, Char>(formatter, spec, fmt) {} +}; + +/** This template formats data and writes the output to a writer. */ +template +class BasicFormatter : private internal::FormatterBase +{ +public: + /** The character type for the output. */ + typedef CharType Char; + +private: + BasicWriter &writer_; + internal::ArgMap map_; + + FMT_DISALLOW_COPY_AND_ASSIGN(BasicFormatter); + + using internal::FormatterBase::get_arg; + + // Checks if manual indexing is used and returns the argument with + // specified name. + internal::Arg get_arg(BasicStringRef arg_name, const char *&error); + + // Parses argument index and returns corresponding argument. + internal::Arg parse_arg_index(const Char *&s); + + // Parses argument name and returns corresponding argument. + internal::Arg parse_arg_name(const Char *&s); + +public: + /** + \rst + Constructs a ``BasicFormatter`` object. References to the arguments and + the writer are stored in the formatter object so make sure they have + appropriate lifetimes. + \endrst + */ + BasicFormatter(const ArgList &args, BasicWriter &w) + : internal::FormatterBase(args), writer_(w) {} + + /** Returns a reference to the writer associated with this formatter. */ + BasicWriter &writer() + { + return writer_; + } + + /** Formats stored arguments and writes the output to the writer. */ + void format(BasicCStringRef format_str); + + // Formats a single argument and advances format_str, a format string pointer. + const Char *format(const Char *&format_str, const internal::Arg &arg); +}; + +// Generates a comma-separated list with results of applying f to +// numbers 0..n-1. +# define FMT_GEN(n, f) FMT_GEN##n(f) +# define FMT_GEN1(f) f(0) +# define FMT_GEN2(f) FMT_GEN1(f), f(1) +# define FMT_GEN3(f) FMT_GEN2(f), f(2) +# define FMT_GEN4(f) FMT_GEN3(f), f(3) +# define FMT_GEN5(f) FMT_GEN4(f), f(4) +# define FMT_GEN6(f) FMT_GEN5(f), f(5) +# define FMT_GEN7(f) FMT_GEN6(f), f(6) +# define FMT_GEN8(f) FMT_GEN7(f), f(7) +# define FMT_GEN9(f) FMT_GEN8(f), f(8) +# define FMT_GEN10(f) FMT_GEN9(f), f(9) +# define FMT_GEN11(f) FMT_GEN10(f), f(10) +# define FMT_GEN12(f) FMT_GEN11(f), f(11) +# define FMT_GEN13(f) FMT_GEN12(f), f(12) +# define FMT_GEN14(f) FMT_GEN13(f), f(13) +# define FMT_GEN15(f) FMT_GEN14(f), f(14) + +namespace internal +{ +inline uint64_t make_type() +{ + return 0; +} + +template +inline uint64_t make_type(const T &arg) +{ + return MakeValue< BasicFormatter >::type(arg); +} + +template + struct ArgArray; + +template +struct ArgArray +{ + typedef Value Type[N > 0 ? N : 1]; + +template +static Value make(const T &value) +{ +#ifdef __clang__ + Value result = MakeValue(value); + // Workaround a bug in Apple LLVM version 4.2 (clang-425.0.28) of clang: + // https://github.com/fmtlib/fmt/issues/276 + (void)result.custom.format; + return result; +#else + return MakeValue(value); +#endif +} + }; + +template +struct ArgArray +{ + typedef Arg Type[N + 1]; // +1 for the list end Arg::NONE + + template + static Arg make(const T &value) + { + return MakeArg(value); + } +}; + +#if FMT_USE_VARIADIC_TEMPLATES +template +inline uint64_t make_type(const Arg &first, const Args & ... tail) +{ + return make_type(first) | (make_type(tail...) << 4); +} + +#else + +struct ArgType +{ + uint64_t type; + + ArgType() : type(0) {} + + template + ArgType(const T &arg) : type(make_type(arg)) {} +}; + +# define FMT_ARG_TYPE_DEFAULT(n) ArgType t##n = ArgType() + +inline uint64_t make_type(FMT_GEN15(FMT_ARG_TYPE_DEFAULT)) +{ + return t0.type | (t1.type << 4) | (t2.type << 8) | (t3.type << 12) | + (t4.type << 16) | (t5.type << 20) | (t6.type << 24) | (t7.type << 28) | + (t8.type << 32) | (t9.type << 36) | (t10.type << 40) | (t11.type << 44) | + (t12.type << 48) | (t13.type << 52) | (t14.type << 56); +} +#endif +} // namespace internal + +# define FMT_MAKE_TEMPLATE_ARG(n) typename T##n +# define FMT_MAKE_ARG_TYPE(n) T##n +# define FMT_MAKE_ARG(n) const T##n &v##n +# define FMT_ASSIGN_char(n) \ + arr[n] = fmt::internal::MakeValue< fmt::BasicFormatter >(v##n) +# define FMT_ASSIGN_wchar_t(n) \ + arr[n] = fmt::internal::MakeValue< fmt::BasicFormatter >(v##n) + +#if FMT_USE_VARIADIC_TEMPLATES +// Defines a variadic function returning void. +# define FMT_VARIADIC_VOID(func, arg_type) \ + template \ + void func(arg_type arg0, const Args & ... args) { \ + typedef fmt::internal::ArgArray ArgArray; \ + typename ArgArray::Type array{ \ + ArgArray::template make >(args)...}; \ + func(arg0, fmt::ArgList(fmt::internal::make_type(args...), array)); \ + } + +// Defines a variadic constructor. +# define FMT_VARIADIC_CTOR(ctor, func, arg0_type, arg1_type) \ + template \ + ctor(arg0_type arg0, arg1_type arg1, const Args & ... args) { \ + typedef fmt::internal::ArgArray ArgArray; \ + typename ArgArray::Type array{ \ + ArgArray::template make >(args)...}; \ + func(arg0, arg1, fmt::ArgList(fmt::internal::make_type(args...), array)); \ + } + +#else + +# define FMT_MAKE_REF(n) \ + fmt::internal::MakeValue< fmt::BasicFormatter >(v##n) +# define FMT_MAKE_REF2(n) v##n + +// Defines a wrapper for a function taking one argument of type arg_type +// and n additional arguments of arbitrary types. +# define FMT_WRAP1(func, arg_type, n) \ + template \ + inline void func(arg_type arg1, FMT_GEN(n, FMT_MAKE_ARG)) { \ + const fmt::internal::ArgArray::Type array = {FMT_GEN(n, FMT_MAKE_REF)}; \ + func(arg1, fmt::ArgList( \ + fmt::internal::make_type(FMT_GEN(n, FMT_MAKE_REF2)), array)); \ + } + +// Emulates a variadic function returning void on a pre-C++11 compiler. +# define FMT_VARIADIC_VOID(func, arg_type) \ + inline void func(arg_type arg) { func(arg, fmt::ArgList()); } \ + FMT_WRAP1(func, arg_type, 1) FMT_WRAP1(func, arg_type, 2) \ + FMT_WRAP1(func, arg_type, 3) FMT_WRAP1(func, arg_type, 4) \ + FMT_WRAP1(func, arg_type, 5) FMT_WRAP1(func, arg_type, 6) \ + FMT_WRAP1(func, arg_type, 7) FMT_WRAP1(func, arg_type, 8) \ + FMT_WRAP1(func, arg_type, 9) FMT_WRAP1(func, arg_type, 10) + +# define FMT_CTOR(ctor, func, arg0_type, arg1_type, n) \ + template \ + ctor(arg0_type arg0, arg1_type arg1, FMT_GEN(n, FMT_MAKE_ARG)) { \ + const fmt::internal::ArgArray::Type array = {FMT_GEN(n, FMT_MAKE_REF)}; \ + func(arg0, arg1, fmt::ArgList( \ + fmt::internal::make_type(FMT_GEN(n, FMT_MAKE_REF2)), array)); \ + } + +// Emulates a variadic constructor on a pre-C++11 compiler. +# define FMT_VARIADIC_CTOR(ctor, func, arg0_type, arg1_type) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 1) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 2) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 3) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 4) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 5) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 6) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 7) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 8) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 9) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 10) +#endif + +// Generates a comma-separated list with results of applying f to pairs +// (argument, index). +#define FMT_FOR_EACH1(f, x0) f(x0, 0) +#define FMT_FOR_EACH2(f, x0, x1) \ + FMT_FOR_EACH1(f, x0), f(x1, 1) +#define FMT_FOR_EACH3(f, x0, x1, x2) \ + FMT_FOR_EACH2(f, x0 ,x1), f(x2, 2) +#define FMT_FOR_EACH4(f, x0, x1, x2, x3) \ + FMT_FOR_EACH3(f, x0, x1, x2), f(x3, 3) +#define FMT_FOR_EACH5(f, x0, x1, x2, x3, x4) \ + FMT_FOR_EACH4(f, x0, x1, x2, x3), f(x4, 4) +#define FMT_FOR_EACH6(f, x0, x1, x2, x3, x4, x5) \ + FMT_FOR_EACH5(f, x0, x1, x2, x3, x4), f(x5, 5) +#define FMT_FOR_EACH7(f, x0, x1, x2, x3, x4, x5, x6) \ + FMT_FOR_EACH6(f, x0, x1, x2, x3, x4, x5), f(x6, 6) +#define FMT_FOR_EACH8(f, x0, x1, x2, x3, x4, x5, x6, x7) \ + FMT_FOR_EACH7(f, x0, x1, x2, x3, x4, x5, x6), f(x7, 7) +#define FMT_FOR_EACH9(f, x0, x1, x2, x3, x4, x5, x6, x7, x8) \ + FMT_FOR_EACH8(f, x0, x1, x2, x3, x4, x5, x6, x7), f(x8, 8) +#define FMT_FOR_EACH10(f, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) \ + FMT_FOR_EACH9(f, x0, x1, x2, x3, x4, x5, x6, x7, x8), f(x9, 9) + +/** + An error returned by an operating system or a language runtime, + for example a file opening error. +*/ +class SystemError : public internal::RuntimeError +{ +private: + void init(int err_code, CStringRef format_str, ArgList args); + +protected: + int error_code_; + + typedef char Char; // For FMT_VARIADIC_CTOR. + + SystemError() {} + +public: + /** + \rst + Constructs a :class:`fmt::SystemError` object with the description + of the form + + .. parsed-literal:: + **: ** + + where ** is the formatted message and ** is + the system message corresponding to the error code. + *error_code* is a system error code as given by ``errno``. + If *error_code* is not a valid error code such as -1, the system message + may look like "Unknown error -1" and is platform-dependent. + + **Example**:: + + // This throws a SystemError with the description + // cannot open file 'madeup': No such file or directory + // or similar (system message may vary). + const char *filename = "madeup"; + std::FILE *file = std::fopen(filename, "r"); + if (!file) + throw fmt::SystemError(errno, "cannot open file '{}'", filename); + \endrst + */ + SystemError(int error_code, CStringRef message) + { + init(error_code, message, ArgList()); + } + FMT_VARIADIC_CTOR(SystemError, init, int, CStringRef) + + ~SystemError() throw(); + + int error_code() const + { + return error_code_; + } +}; + +/** + \rst + This template provides operations for formatting and writing data into + a character stream. The output is stored in a buffer provided by a subclass + such as :class:`fmt::BasicMemoryWriter`. + + You can use one of the following typedefs for common character types: + + +---------+----------------------+ + | Type | Definition | + +=========+======================+ + | Writer | BasicWriter | + +---------+----------------------+ + | WWriter | BasicWriter | + +---------+----------------------+ + + \endrst + */ +template +class BasicWriter +{ +private: + // Output buffer. + Buffer &buffer_; + + FMT_DISALLOW_COPY_AND_ASSIGN(BasicWriter); + + typedef typename internal::CharTraits::CharPtr CharPtr; + +#if FMT_SECURE_SCL + // Returns pointer value. + static Char *get(CharPtr p) + { + return p.base(); + } +#else + static Char *get(Char *p) + { + return p; + } +#endif + + // Fills the padding around the content and returns the pointer to the + // content area. + static CharPtr fill_padding(CharPtr buffer, + unsigned total_size, std::size_t content_size, wchar_t fill); + + // Grows the buffer by n characters and returns a pointer to the newly + // allocated area. + CharPtr grow_buffer(std::size_t n) + { + std::size_t size = buffer_.size(); + buffer_.resize(size + n); + return internal::make_ptr(&buffer_[size], n); + } + + // Writes an unsigned decimal integer. + template + Char *write_unsigned_decimal(UInt value, unsigned prefix_size = 0) + { + unsigned num_digits = internal::count_digits(value); + Char *ptr = get(grow_buffer(prefix_size + num_digits)); + internal::format_decimal(ptr + prefix_size, value, num_digits); + return ptr; + } + + // Writes a decimal integer. + template + void write_decimal(Int value) + { + typedef typename internal::IntTraits::MainType MainType; + MainType abs_value = static_cast(value); + if (internal::is_negative(value)) + { + abs_value = 0 - abs_value; + *write_unsigned_decimal(abs_value, 1) = '-'; + } + else + { + write_unsigned_decimal(abs_value, 0); + } + } + + // Prepare a buffer for integer formatting. + CharPtr prepare_int_buffer(unsigned num_digits, + const EmptySpec &, const char *prefix, unsigned prefix_size) + { + unsigned size = prefix_size + num_digits; + CharPtr p = grow_buffer(size); + std::uninitialized_copy(prefix, prefix + prefix_size, p); + return p + size - 1; + } + + template + CharPtr prepare_int_buffer(unsigned num_digits, + const Spec &spec, const char *prefix, unsigned prefix_size); + + // Formats an integer. + template + void write_int(T value, Spec spec); + + // Formats a floating-point number (double or long double). + template + void write_double(T value, const FormatSpec &spec); + + // Writes a formatted string. + template + CharPtr write_str(const StrChar *s, std::size_t size, const AlignSpec &spec); + + template + void write_str(const internal::Arg::StringValue &str, + const FormatSpec &spec); + + // This following methods are private to disallow writing wide characters + // and strings to a char stream. If you want to print a wide string as a + // pointer as std::ostream does, cast it to const void*. + // Do not implement! + void operator<<(typename internal::WCharHelper::Unsupported); + void operator<<( + typename internal::WCharHelper::Unsupported); + + // Appends floating-point length specifier to the format string. + // The second argument is only used for overload resolution. + void append_float_length(Char *&format_ptr, long double) + { + *format_ptr++ = 'L'; + } + + template + void append_float_length(Char *&, T) {} + + template + friend class internal::ArgFormatterBase; + + friend class internal::PrintfArgFormatter; + +protected: + /** + Constructs a ``BasicWriter`` object. + */ + explicit BasicWriter(Buffer &b) : buffer_(b) {} + +public: + /** + \rst + Destroys a ``BasicWriter`` object. + \endrst + */ + virtual ~BasicWriter() {} + + /** + Returns the total number of characters written. + */ + std::size_t size() const + { + return buffer_.size(); + } + + /** + Returns a pointer to the output buffer content. No terminating null + character is appended. + */ + const Char *data() const FMT_NOEXCEPT + { + return &buffer_[0]; + } + + /** + Returns a pointer to the output buffer content with terminating null + character appended. + */ + const Char *c_str() const + { + std::size_t size = buffer_.size(); + buffer_.reserve(size + 1); + buffer_[size] = '\0'; + return &buffer_[0]; + } + + /** + \rst + Returns the content of the output buffer as an `std::string`. + \endrst + */ + std::basic_string str() const + { + return std::basic_string(&buffer_[0], buffer_.size()); + } + + /** + \rst + Writes formatted data. + + *args* is an argument list representing arbitrary arguments. + + **Example**:: + + MemoryWriter out; + out.write("Current point:\n"); + out.write("({:+f}, {:+f})", -3.14, 3.14); + + This will write the following output to the ``out`` object: + + .. code-block:: none + + Current point: + (-3.140000, +3.140000) + + The output can be accessed using :func:`data()`, :func:`c_str` or + :func:`str` methods. + + See also :ref:`syntax`. + \endrst + */ + void write(BasicCStringRef format, ArgList args) + { + BasicFormatter(args, *this).format(format); + } + FMT_VARIADIC_VOID(write, BasicCStringRef) + + BasicWriter &operator<<(int value) + { + write_decimal(value); + return *this; + } + BasicWriter &operator<<(unsigned value) + { + return *this << IntFormatSpec(value); + } + BasicWriter &operator<<(long value) + { + write_decimal(value); + return *this; + } + BasicWriter &operator<<(unsigned long value) + { + return *this << IntFormatSpec(value); + } + BasicWriter &operator<<(LongLong value) + { + write_decimal(value); + return *this; + } + + /** + \rst + Formats *value* and writes it to the stream. + \endrst + */ + BasicWriter &operator<<(ULongLong value) + { + return *this << IntFormatSpec(value); + } + + BasicWriter &operator<<(double value) + { + write_double(value, FormatSpec()); + return *this; + } + + /** + \rst + Formats *value* using the general format for floating-point numbers + (``'g'``) and writes it to the stream. + \endrst + */ + BasicWriter &operator<<(long double value) + { + write_double(value, FormatSpec()); + return *this; + } + + /** + Writes a character to the stream. + */ + BasicWriter &operator<<(char value) + { + buffer_.push_back(value); + return *this; + } + + BasicWriter &operator<<( + typename internal::WCharHelper::Supported value) + { + buffer_.push_back(value); + return *this; + } + + /** + \rst + Writes *value* to the stream. + \endrst + */ + BasicWriter &operator<<(fmt::BasicStringRef value) + { + const Char *str = value.data(); + buffer_.append(str, str + value.size()); + return *this; + } + + BasicWriter &operator<<( + typename internal::WCharHelper::Supported value) + { + const char *str = value.data(); + buffer_.append(str, str + value.size()); + return *this; + } + + template + BasicWriter &operator<<(IntFormatSpec spec) + { + internal::CharTraits::convert(FillChar()); + write_int(spec.value(), spec); + return *this; + } + + template + BasicWriter &operator<<(const StrFormatSpec &spec) + { + const StrChar *s = spec.str(); + write_str(s, std::char_traits::length(s), spec); + return *this; + } + + void clear() FMT_NOEXCEPT { buffer_.clear(); } + + Buffer &buffer() FMT_NOEXCEPT { return buffer_; } +}; + +template +template +typename BasicWriter::CharPtr BasicWriter::write_str( + const StrChar *s, std::size_t size, const AlignSpec &spec) +{ + CharPtr out = CharPtr(); + if (spec.width() > size) + { + out = grow_buffer(spec.width()); + Char fill = internal::CharTraits::cast(spec.fill()); + if (spec.align() == ALIGN_RIGHT) + { + std::uninitialized_fill_n(out, spec.width() - size, fill); + out += spec.width() - size; + } + else if (spec.align() == ALIGN_CENTER) + { + out = fill_padding(out, spec.width(), size, fill); + } + else + { + std::uninitialized_fill_n(out + size, spec.width() - size, fill); + } + } + else + { + out = grow_buffer(size); + } + std::uninitialized_copy(s, s + size, out); + return out; +} + +template +template +void BasicWriter::write_str( + const internal::Arg::StringValue &s, const FormatSpec &spec) +{ + // Check if StrChar is convertible to Char. + internal::CharTraits::convert(StrChar()); + if (spec.type_ && spec.type_ != 's') + internal::report_unknown_type(spec.type_, "string"); + const StrChar *str_value = s.value; + std::size_t str_size = s.size; + if (str_size == 0) + { + if (!str_value) + { + FMT_THROW(FormatError("string pointer is null")); + } + } + std::size_t precision = static_cast(spec.precision_); + if (spec.precision_ >= 0 && precision < str_size) + str_size = precision; + write_str(str_value, str_size, spec); +} + +template +typename BasicWriter::CharPtr +BasicWriter::fill_padding( + CharPtr buffer, unsigned total_size, + std::size_t content_size, wchar_t fill) +{ + std::size_t padding = total_size - content_size; + std::size_t left_padding = padding / 2; + Char fill_char = internal::CharTraits::cast(fill); + std::uninitialized_fill_n(buffer, left_padding, fill_char); + buffer += left_padding; + CharPtr content = buffer; + std::uninitialized_fill_n(buffer + content_size, + padding - left_padding, fill_char); + return content; +} + +template +template +typename BasicWriter::CharPtr +BasicWriter::prepare_int_buffer( + unsigned num_digits, const Spec &spec, + const char *prefix, unsigned prefix_size) +{ + unsigned width = spec.width(); + Alignment align = spec.align(); + Char fill = internal::CharTraits::cast(spec.fill()); + if (spec.precision() > static_cast(num_digits)) + { + // Octal prefix '0' is counted as a digit, so ignore it if precision + // is specified. + if (prefix_size > 0 && prefix[prefix_size - 1] == '0') + --prefix_size; + unsigned number_size = + prefix_size + internal::to_unsigned(spec.precision()); + AlignSpec subspec(number_size, '0', ALIGN_NUMERIC); + if (number_size >= width) + return prepare_int_buffer(num_digits, subspec, prefix, prefix_size); + buffer_.reserve(width); + unsigned fill_size = width - number_size; + if (align != ALIGN_LEFT) + { + CharPtr p = grow_buffer(fill_size); + std::uninitialized_fill(p, p + fill_size, fill); + } + CharPtr result = prepare_int_buffer( + num_digits, subspec, prefix, prefix_size); + if (align == ALIGN_LEFT) + { + CharPtr p = grow_buffer(fill_size); + std::uninitialized_fill(p, p + fill_size, fill); + } + return result; + } + unsigned size = prefix_size + num_digits; + if (width <= size) + { + CharPtr p = grow_buffer(size); + std::uninitialized_copy(prefix, prefix + prefix_size, p); + return p + size - 1; + } + CharPtr p = grow_buffer(width); + CharPtr end = p + width; + if (align == ALIGN_LEFT) + { + std::uninitialized_copy(prefix, prefix + prefix_size, p); + p += size; + std::uninitialized_fill(p, end, fill); + } + else if (align == ALIGN_CENTER) + { + p = fill_padding(p, width, size, fill); + std::uninitialized_copy(prefix, prefix + prefix_size, p); + p += size; + } + else + { + if (align == ALIGN_NUMERIC) + { + if (prefix_size != 0) + { + p = std::uninitialized_copy(prefix, prefix + prefix_size, p); + size -= prefix_size; + } + } + else + { + std::uninitialized_copy(prefix, prefix + prefix_size, end - size); + } + std::uninitialized_fill(p, end - size, fill); + p = end; + } + return p - 1; +} + +template +template +void BasicWriter::write_int(T value, Spec spec) +{ + unsigned prefix_size = 0; + typedef typename internal::IntTraits::MainType UnsignedType; + UnsignedType abs_value = static_cast(value); + char prefix[4] = ""; + if (internal::is_negative(value)) + { + prefix[0] = '-'; + ++prefix_size; + abs_value = 0 - abs_value; + } + else if (spec.flag(SIGN_FLAG)) + { + prefix[0] = spec.flag(PLUS_FLAG) ? '+' : ' '; + ++prefix_size; + } + switch (spec.type()) + { + case 0: + case 'd': + { + unsigned num_digits = internal::count_digits(abs_value); + CharPtr p = prepare_int_buffer(num_digits, spec, prefix, prefix_size) + 1; + internal::format_decimal(get(p), abs_value, 0); + break; + } + case 'x': + case 'X': + { + UnsignedType n = abs_value; + if (spec.flag(HASH_FLAG)) + { + prefix[prefix_size++] = '0'; + prefix[prefix_size++] = spec.type(); + } + unsigned num_digits = 0; + do + { + ++num_digits; + } + while ((n >>= 4) != 0); + Char *p = get(prepare_int_buffer( + num_digits, spec, prefix, prefix_size)); + n = abs_value; + const char *digits = spec.type() == 'x' ? + "0123456789abcdef" : "0123456789ABCDEF"; + do + { + *p-- = digits[n & 0xf]; + } + while ((n >>= 4) != 0); + break; + } + case 'b': + case 'B': + { + UnsignedType n = abs_value; + if (spec.flag(HASH_FLAG)) + { + prefix[prefix_size++] = '0'; + prefix[prefix_size++] = spec.type(); + } + unsigned num_digits = 0; + do + { + ++num_digits; + } + while ((n >>= 1) != 0); + Char *p = get(prepare_int_buffer(num_digits, spec, prefix, prefix_size)); + n = abs_value; + do + { + *p-- = static_cast('0' + (n & 1)); + } + while ((n >>= 1) != 0); + break; + } + case 'o': + { + UnsignedType n = abs_value; + if (spec.flag(HASH_FLAG)) + prefix[prefix_size++] = '0'; + unsigned num_digits = 0; + do + { + ++num_digits; + } + while ((n >>= 3) != 0); + Char *p = get(prepare_int_buffer(num_digits, spec, prefix, prefix_size)); + n = abs_value; + do + { + *p-- = static_cast('0' + (n & 7)); + } + while ((n >>= 3) != 0); + break; + } + case 'n': + { + unsigned num_digits = internal::count_digits(abs_value); + fmt::StringRef sep = ""; +#ifndef ANDROID + sep = internal::thousands_sep(std::localeconv()); +#endif + unsigned size = static_cast( + num_digits + sep.size() * ((num_digits - 1) / 3)); + CharPtr p = prepare_int_buffer(size, spec, prefix, prefix_size) + 1; + internal::format_decimal(get(p), abs_value, 0, internal::ThousandsSep(sep)); + break; + } + default: + internal::report_unknown_type( + spec.type(), spec.flag(CHAR_FLAG) ? "char" : "integer"); + break; + } +} + +template +template +void BasicWriter::write_double(T value, const FormatSpec &spec) +{ + // Check type. + char type = spec.type(); + bool upper = false; + switch (type) + { + case 0: + type = 'g'; + break; + case 'e': + case 'f': + case 'g': + case 'a': + break; + case 'F': +#if FMT_MSC_VER + // MSVC's printf doesn't support 'F'. + type = 'f'; +#endif + // Fall through. + case 'E': + case 'G': + case 'A': + upper = true; + break; + default: + internal::report_unknown_type(type, "double"); + break; + } + + char sign = 0; + // Use isnegative instead of value < 0 because the latter is always + // false for NaN. + if (internal::FPUtil::isnegative(static_cast(value))) + { + sign = '-'; + value = -value; + } + else if (spec.flag(SIGN_FLAG)) + { + sign = spec.flag(PLUS_FLAG) ? '+' : ' '; + } + + if (internal::FPUtil::isnotanumber(value)) + { + // Format NaN ourselves because sprintf's output is not consistent + // across platforms. + std::size_t nan_size = 4; + const char *nan = upper ? " NAN" : " nan"; + if (!sign) + { + --nan_size; + ++nan; + } + CharPtr out = write_str(nan, nan_size, spec); + if (sign) + *out = sign; + return; + } + + if (internal::FPUtil::isinfinity(value)) + { + // Format infinity ourselves because sprintf's output is not consistent + // across platforms. + std::size_t inf_size = 4; + const char *inf = upper ? " INF" : " inf"; + if (!sign) + { + --inf_size; + ++inf; + } + CharPtr out = write_str(inf, inf_size, spec); + if (sign) + *out = sign; + return; + } + + std::size_t offset = buffer_.size(); + unsigned width = spec.width(); + if (sign) + { + buffer_.reserve(buffer_.size() + (width > 1u ? width : 1u)); + if (width > 0) + --width; + ++offset; + } + + // Build format string. + enum { MAX_FORMAT_SIZE = 10}; // longest format: %#-*.*Lg + Char format[MAX_FORMAT_SIZE]; + Char *format_ptr = format; + *format_ptr++ = '%'; + unsigned width_for_sprintf = width; + if (spec.flag(HASH_FLAG)) + *format_ptr++ = '#'; + if (spec.align() == ALIGN_CENTER) + { + width_for_sprintf = 0; + } + else + { + if (spec.align() == ALIGN_LEFT) + *format_ptr++ = '-'; + if (width != 0) + *format_ptr++ = '*'; + } + if (spec.precision() >= 0) + { + *format_ptr++ = '.'; + *format_ptr++ = '*'; + } + + append_float_length(format_ptr, value); + *format_ptr++ = type; + *format_ptr = '\0'; + + // Format using snprintf. + Char fill = internal::CharTraits::cast(spec.fill()); + unsigned n = 0; + Char *start = 0; + for (;;) + { + std::size_t buffer_size = buffer_.capacity() - offset; +#if FMT_MSC_VER + // MSVC's vsnprintf_s doesn't work with zero size, so reserve + // space for at least one extra character to make the size non-zero. + // Note that the buffer's capacity will increase by more than 1. + if (buffer_size == 0) + { + buffer_.reserve(offset + 1); + buffer_size = buffer_.capacity() - offset; + } +#endif + start = &buffer_[offset]; + int result = internal::CharTraits::format_float( + start, buffer_size, format, width_for_sprintf, spec.precision(), value); + if (result >= 0) + { + n = internal::to_unsigned(result); + if (offset + n < buffer_.capacity()) + break; // The buffer is large enough - continue with formatting. + buffer_.reserve(offset + n + 1); + } + else + { + // If result is negative we ask to increase the capacity by at least 1, + // but as std::vector, the buffer grows exponentially. + buffer_.reserve(buffer_.capacity() + 1); + } + } + if (sign) + { + if ((spec.align() != ALIGN_RIGHT && spec.align() != ALIGN_DEFAULT) || + *start != ' ') + { + *(start - 1) = sign; + sign = 0; + } + else + { + *(start - 1) = fill; + } + ++n; + } + if (spec.align() == ALIGN_CENTER && spec.width() > n) + { + width = spec.width(); + CharPtr p = grow_buffer(width); + std::memmove(get(p) + (width - n) / 2, get(p), n * sizeof(Char)); + fill_padding(p, spec.width(), n, fill); + return; + } + if (spec.fill() != ' ' || sign) + { + while (*start == ' ') + *start++ = fill; + if (sign) + *(start - 1) = sign; + } + grow_buffer(n); +} + +/** + \rst + This class template provides operations for formatting and writing data + into a character stream. The output is stored in a memory buffer that grows + dynamically. + + You can use one of the following typedefs for common character types + and the standard allocator: + + +---------------+-----------------------------------------------------+ + | Type | Definition | + +===============+=====================================================+ + | MemoryWriter | BasicMemoryWriter> | + +---------------+-----------------------------------------------------+ + | WMemoryWriter | BasicMemoryWriter> | + +---------------+-----------------------------------------------------+ + + **Example**:: + + MemoryWriter out; + out << "The answer is " << 42 << "\n"; + out.write("({:+f}, {:+f})", -3.14, 3.14); + + This will write the following output to the ``out`` object: + + .. code-block:: none + + The answer is 42 + (-3.140000, +3.140000) + + The output can be converted to an ``std::string`` with ``out.str()`` or + accessed as a C string with ``out.c_str()``. + \endrst + */ +template > +class BasicMemoryWriter : public BasicWriter +{ +private: + internal::MemoryBuffer buffer_; + +public: + explicit BasicMemoryWriter(const Allocator& alloc = Allocator()) + : BasicWriter(buffer_), buffer_(alloc) {} + +#if FMT_USE_RVALUE_REFERENCES + /** + \rst + Constructs a :class:`fmt::BasicMemoryWriter` object moving the content + of the other object to it. + \endrst + */ + BasicMemoryWriter(BasicMemoryWriter &&other) + : BasicWriter(buffer_), buffer_(std::move(other.buffer_)) + { + } + + /** + \rst + Moves the content of the other ``BasicMemoryWriter`` object to this one. + \endrst + */ + BasicMemoryWriter &operator=(BasicMemoryWriter &&other) + { + buffer_ = std::move(other.buffer_); + return *this; + } +#endif +}; + +typedef BasicMemoryWriter MemoryWriter; +typedef BasicMemoryWriter WMemoryWriter; + +/** + \rst + This class template provides operations for formatting and writing data + into a fixed-size array. For writing into a dynamically growing buffer + use :class:`fmt::BasicMemoryWriter`. + + Any write method will throw ``std::runtime_error`` if the output doesn't fit + into the array. + + You can use one of the following typedefs for common character types: + + +--------------+---------------------------+ + | Type | Definition | + +==============+===========================+ + | ArrayWriter | BasicArrayWriter | + +--------------+---------------------------+ + | WArrayWriter | BasicArrayWriter | + +--------------+---------------------------+ + \endrst + */ +template +class BasicArrayWriter : public BasicWriter +{ +private: + internal::FixedBuffer buffer_; + +public: + /** + \rst + Constructs a :class:`fmt::BasicArrayWriter` object for *array* of the + given size. + \endrst + */ + BasicArrayWriter(Char *array, std::size_t size) + : BasicWriter(buffer_), buffer_(array, size) {} + + /** + \rst + Constructs a :class:`fmt::BasicArrayWriter` object for *array* of the + size known at compile time. + \endrst + */ + template + explicit BasicArrayWriter(Char (&array)[SIZE]) + : BasicWriter(buffer_), buffer_(array, SIZE) {} +}; + +typedef BasicArrayWriter ArrayWriter; +typedef BasicArrayWriter WArrayWriter; + +// Reports a system error without throwing an exception. +// Can be used to report errors from destructors. +FMT_API void report_system_error(int error_code, + StringRef message) FMT_NOEXCEPT; + +#if FMT_USE_WINDOWS_H + +/** A Windows error. */ +class WindowsError : public SystemError +{ +private: + FMT_API void init(int error_code, CStringRef format_str, ArgList args); + +public: + /** + \rst + Constructs a :class:`fmt::WindowsError` object with the description + of the form + + .. parsed-literal:: + **: ** + + where ** is the formatted message and ** is the + system message corresponding to the error code. + *error_code* is a Windows error code as given by ``GetLastError``. + If *error_code* is not a valid error code such as -1, the system message + will look like "error -1". + + **Example**:: + + // This throws a WindowsError with the description + // cannot open file 'madeup': The system cannot find the file specified. + // or similar (system message may vary). + const char *filename = "madeup"; + LPOFSTRUCT of = LPOFSTRUCT(); + HFILE file = OpenFile(filename, &of, OF_READ); + if (file == HFILE_ERROR) { + throw fmt::WindowsError(GetLastError(), + "cannot open file '{}'", filename); + } + \endrst + */ + WindowsError(int error_code, CStringRef message) + { + init(error_code, message, ArgList()); + } + FMT_VARIADIC_CTOR(WindowsError, init, int, CStringRef) +}; + +// Reports a Windows error without throwing an exception. +// Can be used to report errors from destructors. +FMT_API void report_windows_error(int error_code, + StringRef message) FMT_NOEXCEPT; + +#endif + +enum Color { BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE }; + +/** + Formats a string and prints it to stdout using ANSI escape sequences + to specify color (experimental). + Example: + print_colored(fmt::RED, "Elapsed time: {0:.2f} seconds", 1.23); + */ +FMT_API void print_colored(Color c, CStringRef format, ArgList args); + +/** + \rst + Formats arguments and returns the result as a string. + + **Example**:: + + std::string message = format("The answer is {}", 42); + \endrst +*/ +inline std::string format(CStringRef format_str, ArgList args) +{ + MemoryWriter w; + w.write(format_str, args); + return w.str(); +} + +inline std::wstring format(WCStringRef format_str, ArgList args) +{ + WMemoryWriter w; + w.write(format_str, args); + return w.str(); +} + +/** + \rst + Prints formatted data to the file *f*. + + **Example**:: + + print(stderr, "Don't {}!", "panic"); + \endrst + */ +FMT_API void print(std::FILE *f, CStringRef format_str, ArgList args); + +/** + \rst + Prints formatted data to ``stdout``. + + **Example**:: + + print("Elapsed time: {0:.2f} seconds", 1.23); + \endrst + */ +FMT_API void print(CStringRef format_str, ArgList args); + +template +void printf(BasicWriter &w, BasicCStringRef format, ArgList args) +{ + internal::PrintfFormatter(args).format(w, format); +} + +/** + \rst + Formats arguments and returns the result as a string. + + **Example**:: + + std::string message = fmt::sprintf("The answer is %d", 42); + \endrst +*/ +inline std::string sprintf(CStringRef format, ArgList args) +{ + MemoryWriter w; + printf(w, format, args); + return w.str(); +} + +inline std::wstring sprintf(WCStringRef format, ArgList args) +{ + WMemoryWriter w; + printf(w, format, args); + return w.str(); +} + +/** + \rst + Prints formatted data to the file *f*. + + **Example**:: + + fmt::fprintf(stderr, "Don't %s!", "panic"); + \endrst + */ +FMT_API int fprintf(std::FILE *f, CStringRef format, ArgList args); + +/** + \rst + Prints formatted data to ``stdout``. + + **Example**:: + + fmt::printf("Elapsed time: %.2f seconds", 1.23); + \endrst + */ +inline int printf(CStringRef format, ArgList args) +{ + return fprintf(stdout, format, args); +} + +/** + Fast integer formatter. + */ +class FormatInt +{ +private: + // Buffer should be large enough to hold all digits (digits10 + 1), + // a sign and a null character. + enum {BUFFER_SIZE = std::numeric_limits::digits10 + 3}; + mutable char buffer_[BUFFER_SIZE]; + char *str_; + + // Formats value in reverse and returns the number of digits. + char *format_decimal(ULongLong value) + { + char *buffer_end = buffer_ + BUFFER_SIZE - 1; + while (value >= 100) + { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + unsigned index = static_cast((value % 100) * 2); + value /= 100; + *--buffer_end = internal::Data::DIGITS[index + 1]; + *--buffer_end = internal::Data::DIGITS[index]; + } + if (value < 10) + { + *--buffer_end = static_cast('0' + value); + return buffer_end; + } + unsigned index = static_cast(value * 2); + *--buffer_end = internal::Data::DIGITS[index + 1]; + *--buffer_end = internal::Data::DIGITS[index]; + return buffer_end; + } + + void FormatSigned(LongLong value) + { + ULongLong abs_value = static_cast(value); + bool negative = value < 0; + if (negative) + abs_value = 0 - abs_value; + str_ = format_decimal(abs_value); + if (negative) + *--str_ = '-'; + } + +public: + explicit FormatInt(int value) + { + FormatSigned(value); + } + explicit FormatInt(long value) + { + FormatSigned(value); + } + explicit FormatInt(LongLong value) + { + FormatSigned(value); + } + explicit FormatInt(unsigned value) : str_(format_decimal(value)) {} + explicit FormatInt(unsigned long value) : str_(format_decimal(value)) {} + explicit FormatInt(ULongLong value) : str_(format_decimal(value)) {} + + /** Returns the number of characters written to the output buffer. */ + std::size_t size() const + { + return internal::to_unsigned(buffer_ - str_ + BUFFER_SIZE - 1); + } + + /** + Returns a pointer to the output buffer content. No terminating null + character is appended. + */ + const char *data() const + { + return str_; + } + + /** + Returns a pointer to the output buffer content with terminating null + character appended. + */ + const char *c_str() const + { + buffer_[BUFFER_SIZE - 1] = '\0'; + return str_; + } + + /** + \rst + Returns the content of the output buffer as an ``std::string``. + \endrst + */ + std::string str() const + { + return std::string(str_, size()); + } +}; + +// Formats a decimal integer value writing into buffer and returns +// a pointer to the end of the formatted string. This function doesn't +// write a terminating null character. +template +inline void format_decimal(char *&buffer, T value) +{ + typedef typename internal::IntTraits::MainType MainType; + MainType abs_value = static_cast(value); + if (internal::is_negative(value)) + { + *buffer++ = '-'; + abs_value = 0 - abs_value; + } + if (abs_value < 100) + { + if (abs_value < 10) + { + *buffer++ = static_cast('0' + abs_value); + return; + } + unsigned index = static_cast(abs_value * 2); + *buffer++ = internal::Data::DIGITS[index]; + *buffer++ = internal::Data::DIGITS[index + 1]; + return; + } + unsigned num_digits = internal::count_digits(abs_value); + internal::format_decimal(buffer, abs_value, num_digits); + buffer += num_digits; +} + +/** + \rst + Returns a named argument for formatting functions. + + **Example**:: + + print("Elapsed time: {s:.2f} seconds", arg("s", 1.23)); + + \endrst + */ +template +inline internal::NamedArg arg(StringRef name, const T &arg) +{ + return internal::NamedArg(name, arg); +} + +template +inline internal::NamedArg arg(WStringRef name, const T &arg) +{ + return internal::NamedArg(name, arg); +} + +// The following two functions are deleted intentionally to disable +// nested named arguments as in ``format("{}", arg("a", arg("b", 42)))``. +template +void arg(StringRef, const internal::NamedArg&) FMT_DELETED_OR_UNDEFINED; +template +void arg(WStringRef, const internal::NamedArg&) FMT_DELETED_OR_UNDEFINED; +} + +#if FMT_GCC_VERSION +// Use the system_header pragma to suppress warnings about variadic macros +// because suppressing -Wvariadic-macros with the diagnostic pragma doesn't +// work. It is used at the end because we want to suppress as little warnings +// as possible. +# pragma GCC system_header +#endif + +// This is used to work around VC++ bugs in handling variadic macros. +#define FMT_EXPAND(args) args + +// Returns the number of arguments. +// Based on https://groups.google.com/forum/#!topic/comp.std.c/d-6Mj5Lko_s. +#define FMT_NARG(...) FMT_NARG_(__VA_ARGS__, FMT_RSEQ_N()) +#define FMT_NARG_(...) FMT_EXPAND(FMT_ARG_N(__VA_ARGS__)) +#define FMT_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define FMT_RSEQ_N() 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + +#define FMT_CONCAT(a, b) a##b +#define FMT_FOR_EACH_(N, f, ...) \ + FMT_EXPAND(FMT_CONCAT(FMT_FOR_EACH, N)(f, __VA_ARGS__)) +#define FMT_FOR_EACH(f, ...) \ + FMT_EXPAND(FMT_FOR_EACH_(FMT_NARG(__VA_ARGS__), f, __VA_ARGS__)) + +#define FMT_ADD_ARG_NAME(type, index) type arg##index +#define FMT_GET_ARG_NAME(type, index) arg##index + +#if FMT_USE_VARIADIC_TEMPLATES +# define FMT_VARIADIC_(Char, ReturnType, func, call, ...) \ + template \ + ReturnType func(FMT_FOR_EACH(FMT_ADD_ARG_NAME, __VA_ARGS__), \ + const Args & ... args) { \ + typedef fmt::internal::ArgArray ArgArray; \ + typename ArgArray::Type array{ \ + ArgArray::template make >(args)...}; \ + call(FMT_FOR_EACH(FMT_GET_ARG_NAME, __VA_ARGS__), \ + fmt::ArgList(fmt::internal::make_type(args...), array)); \ + } +#else +// Defines a wrapper for a function taking __VA_ARGS__ arguments +// and n additional arguments of arbitrary types. +# define FMT_WRAP(Char, ReturnType, func, call, n, ...) \ + template \ + inline ReturnType func(FMT_FOR_EACH(FMT_ADD_ARG_NAME, __VA_ARGS__), \ + FMT_GEN(n, FMT_MAKE_ARG)) { \ + fmt::internal::ArgArray::Type arr; \ + FMT_GEN(n, FMT_ASSIGN_##Char); \ + call(FMT_FOR_EACH(FMT_GET_ARG_NAME, __VA_ARGS__), fmt::ArgList( \ + fmt::internal::make_type(FMT_GEN(n, FMT_MAKE_REF2)), arr)); \ + } + +# define FMT_VARIADIC_(Char, ReturnType, func, call, ...) \ + inline ReturnType func(FMT_FOR_EACH(FMT_ADD_ARG_NAME, __VA_ARGS__)) { \ + call(FMT_FOR_EACH(FMT_GET_ARG_NAME, __VA_ARGS__), fmt::ArgList()); \ + } \ + FMT_WRAP(Char, ReturnType, func, call, 1, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 2, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 3, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 4, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 5, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 6, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 7, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 8, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 9, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 10, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 11, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 12, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 13, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 14, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 15, __VA_ARGS__) +#endif // FMT_USE_VARIADIC_TEMPLATES + +/** + \rst + Defines a variadic function with the specified return type, function name + and argument types passed as variable arguments to this macro. + + **Example**:: + + void print_error(const char *file, int line, const char *format, + fmt::ArgList args) { + fmt::print("{}: {}: ", file, line); + fmt::print(format, args); + } + FMT_VARIADIC(void, print_error, const char *, int, const char *) + + ``FMT_VARIADIC`` is used for compatibility with legacy C++ compilers that + don't implement variadic templates. You don't have to use this macro if + you don't need legacy compiler support and can use variadic templates + directly:: + + template + void print_error(const char *file, int line, const char *format, + const Args & ... args) { + fmt::print("{}: {}: ", file, line); + fmt::print(format, args...); + } + \endrst + */ +#define FMT_VARIADIC(ReturnType, func, ...) \ + FMT_VARIADIC_(char, ReturnType, func, return func, __VA_ARGS__) + +#define FMT_VARIADIC_W(ReturnType, func, ...) \ + FMT_VARIADIC_(wchar_t, ReturnType, func, return func, __VA_ARGS__) + +#define FMT_CAPTURE_ARG_(id, index) ::fmt::arg(#id, id) + +#define FMT_CAPTURE_ARG_W_(id, index) ::fmt::arg(L###id, id) + +/** + \rst + Convenient macro to capture the arguments' names and values into several + ``fmt::arg(name, value)``. + + **Example**:: + + int x = 1, y = 2; + print("point: ({x}, {y})", FMT_CAPTURE(x, y)); + // same as: + // print("point: ({x}, {y})", arg("x", x), arg("y", y)); + + \endrst + */ +#define FMT_CAPTURE(...) FMT_FOR_EACH(FMT_CAPTURE_ARG_, __VA_ARGS__) + +#define FMT_CAPTURE_W(...) FMT_FOR_EACH(FMT_CAPTURE_ARG_W_, __VA_ARGS__) + +namespace fmt +{ +FMT_VARIADIC(std::string, format, CStringRef) +FMT_VARIADIC_W(std::wstring, format, WCStringRef) +FMT_VARIADIC(void, print, CStringRef) +FMT_VARIADIC(void, print, std::FILE *, CStringRef) + +FMT_VARIADIC(void, print_colored, Color, CStringRef) +FMT_VARIADIC(std::string, sprintf, CStringRef) +FMT_VARIADIC_W(std::wstring, sprintf, WCStringRef) +FMT_VARIADIC(int, printf, CStringRef) +FMT_VARIADIC(int, fprintf, std::FILE *, CStringRef) + +namespace internal +{ +template +inline bool is_name_start(Char c) +{ + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || '_' == c; +} + +// Parses an unsigned integer advancing s to the end of the parsed input. +// This function assumes that the first character of s is a digit. +template +unsigned parse_nonnegative_int(const Char *&s) +{ + assert('0' <= *s && *s <= '9'); + unsigned value = 0; + do + { + unsigned new_value = value * 10 + (*s++ - '0'); + // Check if value wrapped around. + if (new_value < value) + { + value = (std::numeric_limits::max)(); + break; + } + value = new_value; + } + while ('0' <= *s && *s <= '9'); + // Convert to unsigned to prevent a warning. + unsigned max_int = (std::numeric_limits::max)(); + if (value > max_int) + FMT_THROW(FormatError("number is too big")); + return value; +} + +inline void require_numeric_argument(const Arg &arg, char spec) +{ + if (arg.type > Arg::LAST_NUMERIC_TYPE) + { + std::string message = + fmt::format("format specifier '{}' requires numeric argument", spec); + FMT_THROW(fmt::FormatError(message)); + } +} + +template +void check_sign(const Char *&s, const Arg &arg) +{ + char sign = static_cast(*s); + require_numeric_argument(arg, sign); + if (arg.type == Arg::UINT || arg.type == Arg::ULONG_LONG) + { + FMT_THROW(FormatError(fmt::format( + "format specifier '{}' requires signed argument", sign))); + } + ++s; +} +} // namespace internal + +template +inline internal::Arg BasicFormatter::get_arg( + BasicStringRef arg_name, const char *&error) +{ + if (check_no_auto_index(error)) + { + map_.init(args()); + const internal::Arg *arg = map_.find(arg_name); + if (arg) + return *arg; + error = "argument not found"; + } + return internal::Arg(); +} + +template +inline internal::Arg BasicFormatter::parse_arg_index(const Char *&s) +{ + const char *error = 0; + internal::Arg arg = *s < '0' || *s > '9' ? + next_arg(error) : get_arg(internal::parse_nonnegative_int(s), error); + if (error) + { + FMT_THROW(FormatError( + *s != '}' && *s != ':' ? "invalid format string" : error)); + } + return arg; +} + +template +inline internal::Arg BasicFormatter::parse_arg_name(const Char *&s) +{ + assert(internal::is_name_start(*s)); + const Char *start = s; + Char c; + do + { + c = *++s; + } + while (internal::is_name_start(c) || ('0' <= c && c <= '9')); + const char *error = 0; + internal::Arg arg = get_arg(BasicStringRef(start, s - start), error); + if (error) + FMT_THROW(FormatError(error)); + return arg; +} + +template +const Char *BasicFormatter::format( + const Char *&format_str, const internal::Arg &arg) +{ + using internal::Arg; + const Char *s = format_str; + FormatSpec spec; + if (*s == ':') + { + if (arg.type == Arg::CUSTOM) + { + arg.custom.format(this, arg.custom.value, &s); + return s; + } + ++s; + // Parse fill and alignment. + if (Char c = *s) + { + const Char *p = s + 1; + spec.align_ = ALIGN_DEFAULT; + do + { + switch (*p) + { + case '<': + spec.align_ = ALIGN_LEFT; + break; + case '>': + spec.align_ = ALIGN_RIGHT; + break; + case '=': + spec.align_ = ALIGN_NUMERIC; + break; + case '^': + spec.align_ = ALIGN_CENTER; + break; + } + if (spec.align_ != ALIGN_DEFAULT) + { + if (p != s) + { + if (c == '}') break; + if (c == '{') + FMT_THROW(FormatError("invalid fill character '{'")); + s += 2; + spec.fill_ = c; + } + else ++s; + if (spec.align_ == ALIGN_NUMERIC) + require_numeric_argument(arg, '='); + break; + } + } + while (--p >= s); + } + + // Parse sign. + switch (*s) + { + case '+': + check_sign(s, arg); + spec.flags_ |= SIGN_FLAG | PLUS_FLAG; + break; + case '-': + check_sign(s, arg); + spec.flags_ |= MINUS_FLAG; + break; + case ' ': + check_sign(s, arg); + spec.flags_ |= SIGN_FLAG; + break; + } + + if (*s == '#') + { + require_numeric_argument(arg, '#'); + spec.flags_ |= HASH_FLAG; + ++s; + } + + // Parse zero flag. + if (*s == '0') + { + require_numeric_argument(arg, '0'); + spec.align_ = ALIGN_NUMERIC; + spec.fill_ = '0'; + ++s; + } + + // Parse width. + if ('0' <= *s && *s <= '9') + { + spec.width_ = internal::parse_nonnegative_int(s); + } + else if (*s == '{') + { + ++s; + Arg width_arg = internal::is_name_start(*s) ? + parse_arg_name(s) : parse_arg_index(s); + if (*s++ != '}') + FMT_THROW(FormatError("invalid format string")); + ULongLong value = 0; + switch (width_arg.type) + { + case Arg::INT: + if (width_arg.int_value < 0) + FMT_THROW(FormatError("negative width")); + value = width_arg.int_value; + break; + case Arg::UINT: + value = width_arg.uint_value; + break; + case Arg::LONG_LONG: + if (width_arg.long_long_value < 0) + FMT_THROW(FormatError("negative width")); + value = width_arg.long_long_value; + break; + case Arg::ULONG_LONG: + value = width_arg.ulong_long_value; + break; + default: + FMT_THROW(FormatError("width is not integer")); + } + if (value > (std::numeric_limits::max)()) + FMT_THROW(FormatError("number is too big")); + spec.width_ = static_cast(value); + } + + // Parse precision. + if (*s == '.') + { + ++s; + spec.precision_ = 0; + if ('0' <= *s && *s <= '9') + { + spec.precision_ = internal::parse_nonnegative_int(s); + } + else if (*s == '{') + { + ++s; + Arg precision_arg = internal::is_name_start(*s) ? + parse_arg_name(s) : parse_arg_index(s); + if (*s++ != '}') + FMT_THROW(FormatError("invalid format string")); + ULongLong value = 0; + switch (precision_arg.type) + { + case Arg::INT: + if (precision_arg.int_value < 0) + FMT_THROW(FormatError("negative precision")); + value = precision_arg.int_value; + break; + case Arg::UINT: + value = precision_arg.uint_value; + break; + case Arg::LONG_LONG: + if (precision_arg.long_long_value < 0) + FMT_THROW(FormatError("negative precision")); + value = precision_arg.long_long_value; + break; + case Arg::ULONG_LONG: + value = precision_arg.ulong_long_value; + break; + default: + FMT_THROW(FormatError("precision is not integer")); + } + if (value > (std::numeric_limits::max)()) + FMT_THROW(FormatError("number is too big")); + spec.precision_ = static_cast(value); + } + else + { + FMT_THROW(FormatError("missing precision specifier")); + } + if (arg.type <= Arg::LAST_INTEGER_TYPE || arg.type == Arg::POINTER) + { + FMT_THROW(FormatError( + fmt::format("precision not allowed in {} format specifier", + arg.type == Arg::POINTER ? "pointer" : "integer"))); + } + } + + // Parse type. + if (*s != '}' && *s) + spec.type_ = static_cast(*s++); + } + + if (*s++ != '}') + FMT_THROW(FormatError("missing '}' in format string")); + + // Format argument. + ArgFormatter(*this, spec, s - 1).visit(arg); + return s; +} + +template +void BasicFormatter::format(BasicCStringRef format_str) +{ + const Char *s = format_str.c_str(); + const Char *start = s; + while (*s) + { + Char c = *s++; + if (c != '{' && c != '}') continue; + if (*s == c) + { + write(writer_, start, s); + start = ++s; + continue; + } + if (c == '}') + FMT_THROW(FormatError("unmatched '}' in format string")); + write(writer_, start, s - 1); + internal::Arg arg = internal::is_name_start(*s) ? + parse_arg_name(s) : parse_arg_index(s); + start = s = format(s, arg); + } + write(writer_, start, s); +} +} // namespace fmt + +#if FMT_USE_USER_DEFINED_LITERALS +namespace fmt +{ +namespace internal +{ + +template +struct UdlFormat +{ + const Char *str; + + template + auto operator()(Args && ... args) const + -> decltype(format(str, std::forward(args)...)) + { + return format(str, std::forward(args)...); + } +}; + +template +struct UdlArg +{ + const Char *str; + + template + NamedArg operator=(T &&value) const + { + return {str, std::forward(value)}; + } +}; + +} // namespace internal + +inline namespace literals +{ + +/** + \rst + C++11 literal equivalent of :func:`fmt::format`. + + **Example**:: + + using namespace fmt::literals; + std::string message = "The answer is {}"_format(42); + \endrst + */ +inline internal::UdlFormat +operator"" _format(const char *s, std::size_t) +{ + return {s}; +} +inline internal::UdlFormat +operator"" _format(const wchar_t *s, std::size_t) +{ + return {s}; +} + +/** + \rst + C++11 literal equivalent of :func:`fmt::arg`. + + **Example**:: + + using namespace fmt::literals; + print("Elapsed time: {s:.2f} seconds", "s"_a=1.23); + \endrst + */ +inline internal::UdlArg +operator"" _a(const char *s, std::size_t) +{ + return {s}; +} +inline internal::UdlArg +operator"" _a(const wchar_t *s, std::size_t) +{ + return {s}; +} + +} // inline namespace literals +} // namespace fmt +#endif // FMT_USE_USER_DEFINED_LITERALS + +// Restore warnings. +#if FMT_GCC_VERSION >= 406 +# pragma GCC diagnostic pop +#endif + +#if defined(__clang__) && !defined(FMT_ICC_VERSION) +# pragma clang diagnostic pop +#endif + +#ifdef FMT_HEADER_ONLY +# define FMT_FUNC inline +# include "format.cc" +#else +# define FMT_FUNC +#endif + +#endif // FMT_FORMAT_H_ diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.cc b/src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.cc new file mode 100755 index 0000000..bcb67fe --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.cc @@ -0,0 +1,43 @@ +/* + Formatting library for C++ - std::ostream support + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + For the license information refer to format.h. + */ + +#include "ostream.h" + +namespace fmt { + +namespace { +// Write the content of w to os. +void write(std::ostream &os, Writer &w) { + const char *data = w.data(); + typedef internal::MakeUnsigned::Type UnsignedStreamSize; + UnsignedStreamSize size = w.size(); + UnsignedStreamSize max_size = + internal::to_unsigned((std::numeric_limits::max)()); + do { + UnsignedStreamSize n = size <= max_size ? size : max_size; + os.write(data, static_cast(n)); + data += n; + size -= n; + } while (size != 0); +} +} + +FMT_FUNC void print(std::ostream &os, CStringRef format_str, ArgList args) { + MemoryWriter w; + w.write(format_str, args); + write(os, w); +} + +FMT_FUNC int fprintf(std::ostream &os, CStringRef format, ArgList args) { + MemoryWriter w; + printf(w, format, args); + write(os, w); + return static_cast(w.size()); +} +} // namespace fmt diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.h b/src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.h new file mode 100755 index 0000000..c52646d --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/ostream.h @@ -0,0 +1,126 @@ +/* + Formatting library for C++ - std::ostream support + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + For the license information refer to format.h. + */ + +#ifndef FMT_OSTREAM_H_ +#define FMT_OSTREAM_H_ + +#include "format.h" +#include + +namespace fmt +{ + +namespace internal +{ + +template +class FormatBuf : public std::basic_streambuf +{ +private: + typedef typename std::basic_streambuf::int_type int_type; + typedef typename std::basic_streambuf::traits_type traits_type; + + Buffer &buffer_; + Char *start_; + +public: + FormatBuf(Buffer &buffer) : buffer_(buffer), start_(&buffer[0]) + { + this->setp(start_, start_ + buffer_.capacity()); + } + + int_type overflow(int_type ch = traits_type::eof()) + { + if (!traits_type::eq_int_type(ch, traits_type::eof())) + { + size_t buf_size = size(); + buffer_.resize(buf_size); + buffer_.reserve(buf_size * 2); + + start_ = &buffer_[0]; + start_[buf_size] = traits_type::to_char_type(ch); + this->setp(start_+ buf_size + 1, start_ + buf_size * 2); + } + return ch; + } + + size_t size() const + { + return to_unsigned(this->pptr() - start_); + } +}; + +Yes &convert(std::ostream &); + +struct DummyStream : std::ostream +{ + DummyStream(); // Suppress a bogus warning in MSVC. + // Hide all operator<< overloads from std::ostream. + void operator<<(Null<>); +}; + +No &operator<<(std::ostream &, int); + +template +struct ConvertToIntImpl +{ + // Convert to int only if T doesn't have an overloaded operator<<. + enum + { + value = sizeof(convert(get() << get())) == sizeof(No) + }; +}; +} // namespace internal + +// Formats a value. +template +void format(BasicFormatter &f, + const Char *&format_str, const T &value) +{ + internal::MemoryBuffer buffer; + + internal::FormatBuf format_buf(buffer); + std::basic_ostream output(&format_buf); + output << value; + + BasicStringRef str(&buffer[0], format_buf.size()); + typedef internal::MakeArg< BasicFormatter > MakeArg; + format_str = f.format(format_str, MakeArg(str)); +} + +/** + \rst + Prints formatted data to the stream *os*. + + **Example**:: + + print(cerr, "Don't {}!", "panic"); + \endrst + */ +FMT_API void print(std::ostream &os, CStringRef format_str, ArgList args); +FMT_VARIADIC(void, print, std::ostream &, CStringRef) + +/** + \rst + Prints formatted data to the stream *os*. + + **Example**:: + + fprintf(cerr, "Don't %s!", "panic"); + \endrst + */ +FMT_API int fprintf(std::ostream &os, CStringRef format_str, ArgList args); +FMT_VARIADIC(int, fprintf, std::ostream &, CStringRef) +} // namespace fmt + +#ifdef FMT_HEADER_ONLY +# include "ostream.cc" +#endif + +#endif // FMT_OSTREAM_H_ diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/posix.cc b/src/dionysus/wasserstein/spdlog/fmt/bundled/posix.cc new file mode 100755 index 0000000..76eb7f0 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/posix.cc @@ -0,0 +1,238 @@ +/* + A C++ interface to POSIX functions. + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + For the license information refer to format.h. + */ + +// Disable bogus MSVC warnings. +#ifndef _CRT_SECURE_NO_WARNINGS +# define _CRT_SECURE_NO_WARNINGS +#endif + +#include "posix.h" + +#include +#include +#include + +#ifndef _WIN32 +# include +#else +# include +# include + +# define O_CREAT _O_CREAT +# define O_TRUNC _O_TRUNC + +# ifndef S_IRUSR +# define S_IRUSR _S_IREAD +# endif + +# ifndef S_IWUSR +# define S_IWUSR _S_IWRITE +# endif + +# ifdef __MINGW32__ +# define _SH_DENYNO 0x40 +# endif + +#endif // _WIN32 + +#ifdef fileno +# undef fileno +#endif + +namespace { +#ifdef _WIN32 +// Return type of read and write functions. +typedef int RWResult; + +// On Windows the count argument to read and write is unsigned, so convert +// it from size_t preventing integer overflow. +inline unsigned convert_rwcount(std::size_t count) { + return count <= UINT_MAX ? static_cast(count) : UINT_MAX; +} +#else +// Return type of read and write functions. +typedef ssize_t RWResult; + +inline std::size_t convert_rwcount(std::size_t count) { return count; } +#endif +} + +fmt::BufferedFile::~BufferedFile() FMT_NOEXCEPT { + if (file_ && FMT_SYSTEM(fclose(file_)) != 0) + fmt::report_system_error(errno, "cannot close file"); +} + +fmt::BufferedFile::BufferedFile( + fmt::CStringRef filename, fmt::CStringRef mode) { + FMT_RETRY_VAL(file_, FMT_SYSTEM(fopen(filename.c_str(), mode.c_str())), 0); + if (!file_) + FMT_THROW(SystemError(errno, "cannot open file {}", filename)); +} + +void fmt::BufferedFile::close() { + if (!file_) + return; + int result = FMT_SYSTEM(fclose(file_)); + file_ = 0; + if (result != 0) + FMT_THROW(SystemError(errno, "cannot close file")); +} + +// A macro used to prevent expansion of fileno on broken versions of MinGW. +#define FMT_ARGS + +int fmt::BufferedFile::fileno() const { + int fd = FMT_POSIX_CALL(fileno FMT_ARGS(file_)); + if (fd == -1) + FMT_THROW(SystemError(errno, "cannot get file descriptor")); + return fd; +} + +fmt::File::File(fmt::CStringRef path, int oflag) { + int mode = S_IRUSR | S_IWUSR; +#if defined(_WIN32) && !defined(__MINGW32__) + fd_ = -1; + FMT_POSIX_CALL(sopen_s(&fd_, path.c_str(), oflag, _SH_DENYNO, mode)); +#else + FMT_RETRY(fd_, FMT_POSIX_CALL(open(path.c_str(), oflag, mode))); +#endif + if (fd_ == -1) + FMT_THROW(SystemError(errno, "cannot open file {}", path)); +} + +fmt::File::~File() FMT_NOEXCEPT { + // Don't retry close in case of EINTR! + // See http://linux.derkeiler.com/Mailing-Lists/Kernel/2005-09/3000.html + if (fd_ != -1 && FMT_POSIX_CALL(close(fd_)) != 0) + fmt::report_system_error(errno, "cannot close file"); +} + +void fmt::File::close() { + if (fd_ == -1) + return; + // Don't retry close in case of EINTR! + // See http://linux.derkeiler.com/Mailing-Lists/Kernel/2005-09/3000.html + int result = FMT_POSIX_CALL(close(fd_)); + fd_ = -1; + if (result != 0) + FMT_THROW(SystemError(errno, "cannot close file")); +} + +fmt::LongLong fmt::File::size() const { +#ifdef _WIN32 + // Use GetFileSize instead of GetFileSizeEx for the case when _WIN32_WINNT + // is less than 0x0500 as is the case with some default MinGW builds. + // Both functions support large file sizes. + DWORD size_upper = 0; + HANDLE handle = reinterpret_cast(_get_osfhandle(fd_)); + DWORD size_lower = FMT_SYSTEM(GetFileSize(handle, &size_upper)); + if (size_lower == INVALID_FILE_SIZE) { + DWORD error = GetLastError(); + if (error != NO_ERROR) + FMT_THROW(WindowsError(GetLastError(), "cannot get file size")); + } + fmt::ULongLong long_size = size_upper; + return (long_size << sizeof(DWORD) * CHAR_BIT) | size_lower; +#else + typedef struct stat Stat; + Stat file_stat = Stat(); + if (FMT_POSIX_CALL(fstat(fd_, &file_stat)) == -1) + FMT_THROW(SystemError(errno, "cannot get file attributes")); + FMT_STATIC_ASSERT(sizeof(fmt::LongLong) >= sizeof(file_stat.st_size), + "return type of File::size is not large enough"); + return file_stat.st_size; +#endif +} + +std::size_t fmt::File::read(void *buffer, std::size_t count) { + RWResult result = 0; + FMT_RETRY(result, FMT_POSIX_CALL(read(fd_, buffer, convert_rwcount(count)))); + if (result < 0) + FMT_THROW(SystemError(errno, "cannot read from file")); + return internal::to_unsigned(result); +} + +std::size_t fmt::File::write(const void *buffer, std::size_t count) { + RWResult result = 0; + FMT_RETRY(result, FMT_POSIX_CALL(write(fd_, buffer, convert_rwcount(count)))); + if (result < 0) + FMT_THROW(SystemError(errno, "cannot write to file")); + return internal::to_unsigned(result); +} + +fmt::File fmt::File::dup(int fd) { + // Don't retry as dup doesn't return EINTR. + // http://pubs.opengroup.org/onlinepubs/009695399/functions/dup.html + int new_fd = FMT_POSIX_CALL(dup(fd)); + if (new_fd == -1) + FMT_THROW(SystemError(errno, "cannot duplicate file descriptor {}", fd)); + return File(new_fd); +} + +void fmt::File::dup2(int fd) { + int result = 0; + FMT_RETRY(result, FMT_POSIX_CALL(dup2(fd_, fd))); + if (result == -1) { + FMT_THROW(SystemError(errno, + "cannot duplicate file descriptor {} to {}", fd_, fd)); + } +} + +void fmt::File::dup2(int fd, ErrorCode &ec) FMT_NOEXCEPT { + int result = 0; + FMT_RETRY(result, FMT_POSIX_CALL(dup2(fd_, fd))); + if (result == -1) + ec = ErrorCode(errno); +} + +void fmt::File::pipe(File &read_end, File &write_end) { + // Close the descriptors first to make sure that assignments don't throw + // and there are no leaks. + read_end.close(); + write_end.close(); + int fds[2] = {}; +#ifdef _WIN32 + // Make the default pipe capacity same as on Linux 2.6.11+. + enum { DEFAULT_CAPACITY = 65536 }; + int result = FMT_POSIX_CALL(pipe(fds, DEFAULT_CAPACITY, _O_BINARY)); +#else + // Don't retry as the pipe function doesn't return EINTR. + // http://pubs.opengroup.org/onlinepubs/009696799/functions/pipe.html + int result = FMT_POSIX_CALL(pipe(fds)); +#endif + if (result != 0) + FMT_THROW(SystemError(errno, "cannot create pipe")); + // The following assignments don't throw because read_fd and write_fd + // are closed. + read_end = File(fds[0]); + write_end = File(fds[1]); +} + +fmt::BufferedFile fmt::File::fdopen(const char *mode) { + // Don't retry as fdopen doesn't return EINTR. + FILE *f = FMT_POSIX_CALL(fdopen(fd_, mode)); + if (!f) + FMT_THROW(SystemError(errno, "cannot associate stream with file descriptor")); + BufferedFile file(f); + fd_ = -1; + return file; +} + +long fmt::getpagesize() { +#ifdef _WIN32 + SYSTEM_INFO si; + GetSystemInfo(&si); + return si.dwPageSize; +#else + long size = FMT_POSIX_CALL(sysconf(_SC_PAGESIZE)); + if (size < 0) + FMT_THROW(SystemError(errno, "cannot get memory page size")); + return size; +#endif +} diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/posix.h b/src/dionysus/wasserstein/spdlog/fmt/bundled/posix.h new file mode 100755 index 0000000..859fcaa --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/posix.h @@ -0,0 +1,443 @@ +/* + A C++ interface to POSIX functions. + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + For the license information refer to format.h. + */ + +#ifndef FMT_POSIX_H_ +#define FMT_POSIX_H_ + +#if defined(__MINGW32__) || defined(__CYGWIN__) +// Workaround MinGW bug https://sourceforge.net/p/mingw/bugs/2024/. +# undef __STRICT_ANSI__ +#endif + +#include +#include // for O_RDONLY +#include // for locale_t +#include +#include // for strtod_l + +#include + +#if defined __APPLE__ || defined(__FreeBSD__) +# include // for LC_NUMERIC_MASK on OS X +#endif + +#include "format.h" + +#ifndef FMT_POSIX +# if defined(_WIN32) && !defined(__MINGW32__) +// Fix warnings about deprecated symbols. +# define FMT_POSIX(call) _##call +# else +# define FMT_POSIX(call) call +# endif +#endif + +// Calls to system functions are wrapped in FMT_SYSTEM for testability. +#ifdef FMT_SYSTEM +# define FMT_POSIX_CALL(call) FMT_SYSTEM(call) +#else +# define FMT_SYSTEM(call) call +# ifdef _WIN32 +// Fix warnings about deprecated symbols. +# define FMT_POSIX_CALL(call) ::_##call +# else +# define FMT_POSIX_CALL(call) ::call +# endif +#endif + +#if FMT_GCC_VERSION >= 407 +# define FMT_UNUSED __attribute__((unused)) +#else +# define FMT_UNUSED +#endif + +#ifndef FMT_USE_STATIC_ASSERT +# define FMT_USE_STATIC_ASSERT 0 +#endif + +#if FMT_USE_STATIC_ASSERT || FMT_HAS_FEATURE(cxx_static_assert) || \ + (FMT_GCC_VERSION >= 403 && FMT_HAS_GXX_CXX11) || _MSC_VER >= 1600 +# define FMT_STATIC_ASSERT(cond, message) static_assert(cond, message) +#else +# define FMT_CONCAT_(a, b) FMT_CONCAT(a, b) +# define FMT_STATIC_ASSERT(cond, message) \ + typedef int FMT_CONCAT_(Assert, __LINE__)[(cond) ? 1 : -1] FMT_UNUSED +#endif + +// Retries the expression while it evaluates to error_result and errno +// equals to EINTR. +#ifndef _WIN32 +# define FMT_RETRY_VAL(result, expression, error_result) \ + do { \ + result = (expression); \ + } while (result == error_result && errno == EINTR) +#else +# define FMT_RETRY_VAL(result, expression, error_result) result = (expression) +#endif + +#define FMT_RETRY(result, expression) FMT_RETRY_VAL(result, expression, -1) + +namespace fmt +{ + +// An error code. +class ErrorCode +{ +private: + int value_; + +public: +explicit ErrorCode(int value = 0) FMT_NOEXCEPT : + value_(value) {} + + int get() const FMT_NOEXCEPT + { + return value_; + } +}; + +// A buffered file. +class BufferedFile +{ +private: + FILE *file_; + + friend class File; + + explicit BufferedFile(FILE *f) : file_(f) {} + +public: + // Constructs a BufferedFile object which doesn't represent any file. +BufferedFile() FMT_NOEXCEPT : + file_(0) {} + + // Destroys the object closing the file it represents if any. + ~BufferedFile() FMT_NOEXCEPT; + +#if !FMT_USE_RVALUE_REFERENCES + // Emulate a move constructor and a move assignment operator if rvalue + // references are not supported. + +private: + // A proxy object to emulate a move constructor. + // It is private to make it impossible call operator Proxy directly. + struct Proxy + { + FILE *file; + }; + +public: + // A "move constructor" for moving from a temporary. +BufferedFile(Proxy p) FMT_NOEXCEPT : + file_(p.file) {} + + // A "move constructor" for moving from an lvalue. +BufferedFile(BufferedFile &f) FMT_NOEXCEPT : + file_(f.file_) + { + f.file_ = 0; + } + + // A "move assignment operator" for moving from a temporary. + BufferedFile &operator=(Proxy p) + { + close(); + file_ = p.file; + return *this; + } + + // A "move assignment operator" for moving from an lvalue. + BufferedFile &operator=(BufferedFile &other) + { + close(); + file_ = other.file_; + other.file_ = 0; + return *this; + } + + // Returns a proxy object for moving from a temporary: + // BufferedFile file = BufferedFile(...); + operator Proxy() FMT_NOEXCEPT + { + Proxy p = {file_}; + file_ = 0; + return p; + } + +#else +private: + FMT_DISALLOW_COPY_AND_ASSIGN(BufferedFile); + +public: +BufferedFile(BufferedFile &&other) FMT_NOEXCEPT : + file_(other.file_) + { + other.file_ = 0; + } + + BufferedFile& operator=(BufferedFile &&other) + { + close(); + file_ = other.file_; + other.file_ = 0; + return *this; + } +#endif + + // Opens a file. + BufferedFile(CStringRef filename, CStringRef mode); + + // Closes the file. + void close(); + + // Returns the pointer to a FILE object representing this file. + FILE *get() const FMT_NOEXCEPT + { + return file_; + } + + // We place parentheses around fileno to workaround a bug in some versions + // of MinGW that define fileno as a macro. + int (fileno)() const; + + void print(CStringRef format_str, const ArgList &args) + { + fmt::print(file_, format_str, args); + } + FMT_VARIADIC(void, print, CStringRef) +}; + +// A file. Closed file is represented by a File object with descriptor -1. +// Methods that are not declared with FMT_NOEXCEPT may throw +// fmt::SystemError in case of failure. Note that some errors such as +// closing the file multiple times will cause a crash on Windows rather +// than an exception. You can get standard behavior by overriding the +// invalid parameter handler with _set_invalid_parameter_handler. +class File +{ +private: + int fd_; // File descriptor. + + // Constructs a File object with a given descriptor. + explicit File(int fd) : fd_(fd) {} + +public: + // Possible values for the oflag argument to the constructor. + enum + { + RDONLY = FMT_POSIX(O_RDONLY), // Open for reading only. + WRONLY = FMT_POSIX(O_WRONLY), // Open for writing only. + RDWR = FMT_POSIX(O_RDWR) // Open for reading and writing. + }; + + // Constructs a File object which doesn't represent any file. +File() FMT_NOEXCEPT : + fd_(-1) {} + + // Opens a file and constructs a File object representing this file. + File(CStringRef path, int oflag); + +#if !FMT_USE_RVALUE_REFERENCES + // Emulate a move constructor and a move assignment operator if rvalue + // references are not supported. + +private: + // A proxy object to emulate a move constructor. + // It is private to make it impossible call operator Proxy directly. + struct Proxy + { + int fd; + }; + +public: + // A "move constructor" for moving from a temporary. +File(Proxy p) FMT_NOEXCEPT : + fd_(p.fd) {} + + // A "move constructor" for moving from an lvalue. +File(File &other) FMT_NOEXCEPT : + fd_(other.fd_) + { + other.fd_ = -1; + } + + // A "move assignment operator" for moving from a temporary. + File &operator=(Proxy p) + { + close(); + fd_ = p.fd; + return *this; + } + + // A "move assignment operator" for moving from an lvalue. + File &operator=(File &other) + { + close(); + fd_ = other.fd_; + other.fd_ = -1; + return *this; + } + + // Returns a proxy object for moving from a temporary: + // File file = File(...); + operator Proxy() FMT_NOEXCEPT + { + Proxy p = {fd_}; + fd_ = -1; + return p; + } + +#else +private: + FMT_DISALLOW_COPY_AND_ASSIGN(File); + +public: +File(File &&other) FMT_NOEXCEPT : + fd_(other.fd_) + { + other.fd_ = -1; + } + + File& operator=(File &&other) + { + close(); + fd_ = other.fd_; + other.fd_ = -1; + return *this; + } +#endif + + // Destroys the object closing the file it represents if any. + ~File() FMT_NOEXCEPT; + + // Returns the file descriptor. + int descriptor() const FMT_NOEXCEPT + { + return fd_; + } + + // Closes the file. + void close(); + + // Returns the file size. The size has signed type for consistency with + // stat::st_size. + LongLong size() const; + + // Attempts to read count bytes from the file into the specified buffer. + std::size_t read(void *buffer, std::size_t count); + + // Attempts to write count bytes from the specified buffer to the file. + std::size_t write(const void *buffer, std::size_t count); + + // Duplicates a file descriptor with the dup function and returns + // the duplicate as a file object. + static File dup(int fd); + + // Makes fd be the copy of this file descriptor, closing fd first if + // necessary. + void dup2(int fd); + + // Makes fd be the copy of this file descriptor, closing fd first if + // necessary. + void dup2(int fd, ErrorCode &ec) FMT_NOEXCEPT; + + // Creates a pipe setting up read_end and write_end file objects for reading + // and writing respectively. + static void pipe(File &read_end, File &write_end); + + // Creates a BufferedFile object associated with this file and detaches + // this File object from the file. + BufferedFile fdopen(const char *mode); +}; + +// Returns the memory page size. +long getpagesize(); + +#if (defined(LC_NUMERIC_MASK) || defined(_MSC_VER)) && \ + !defined(__ANDROID__) && !defined(__CYGWIN__) +# define FMT_LOCALE +#endif + +#ifdef FMT_LOCALE +// A "C" numeric locale. +class Locale +{ +private: +# ifdef _MSC_VER + typedef _locale_t locale_t; + + enum { LC_NUMERIC_MASK = LC_NUMERIC }; + + static locale_t newlocale(int category_mask, const char *locale, locale_t) + { + return _create_locale(category_mask, locale); + } + + static void freelocale(locale_t locale) + { + _free_locale(locale); + } + + static double strtod_l(const char *nptr, char **endptr, _locale_t locale) + { + return _strtod_l(nptr, endptr, locale); + } +# endif + + locale_t locale_; + + FMT_DISALLOW_COPY_AND_ASSIGN(Locale); + +public: + typedef locale_t Type; + + Locale() : locale_(newlocale(LC_NUMERIC_MASK, "C", NULL)) + { + if (!locale_) + FMT_THROW(fmt::SystemError(errno, "cannot create locale")); + } + ~Locale() + { + freelocale(locale_); + } + + Type get() const + { + return locale_; + } + + // Converts string to floating-point number and advances str past the end + // of the parsed input. + double strtod(const char *&str) const + { + char *end = 0; + double result = strtod_l(str, &end, locale_); + str = end; + return result; + } +}; +#endif // FMT_LOCALE +} // namespace fmt + +#if !FMT_USE_RVALUE_REFERENCES +namespace std +{ +// For compatibility with C++98. +inline fmt::BufferedFile &move(fmt::BufferedFile &f) +{ + return f; +} +inline fmt::File &move(fmt::File &f) +{ + return f; +} +} +#endif + +#endif // FMT_POSIX_H_ diff --git a/src/dionysus/wasserstein/spdlog/fmt/bundled/time.h b/src/dionysus/wasserstein/spdlog/fmt/bundled/time.h new file mode 100755 index 0000000..10c6cfc --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/bundled/time.h @@ -0,0 +1,58 @@ +/* + Formatting library for C++ - time formatting + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + For the license information refer to format.h. + */ + +#ifndef FMT_TIME_H_ +#define FMT_TIME_H_ + +#include "format.h" +#include + +namespace fmt +{ +template +void format(BasicFormatter &f, + const char *&format_str, const std::tm &tm) +{ + if (*format_str == ':') + ++format_str; + const char *end = format_str; + while (*end && *end != '}') + ++end; + if (*end != '}') + FMT_THROW(FormatError("missing '}' in format string")); + internal::MemoryBuffer format; + format.append(format_str, end + 1); + format[format.size() - 1] = '\0'; + Buffer &buffer = f.writer().buffer(); + std::size_t start = buffer.size(); + for (;;) + { + std::size_t size = buffer.capacity() - start; + std::size_t count = std::strftime(&buffer[start], size, &format[0], &tm); + if (count != 0) + { + buffer.resize(start + count); + break; + } + if (size >= format.size() * 256) + { + // If the buffer is 256 times larger than the format string, assume + // that `strftime` gives an empty result. There doesn't seem to be a + // better way to distinguish the two cases: + // https://github.com/fmtlib/fmt/issues/367 + break; + } + const std::size_t MIN_GROWTH = 10; + buffer.reserve(buffer.capacity() + (size > MIN_GROWTH ? size : MIN_GROWTH)); + } + format_str = end + 1; +} +} + +#endif // FMT_TIME_H_ diff --git a/src/dionysus/wasserstein/spdlog/fmt/fmt.h b/src/dionysus/wasserstein/spdlog/fmt/fmt.h new file mode 100755 index 0000000..b39104f --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/fmt.h @@ -0,0 +1,28 @@ +// +// Copyright(c) 2016 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// +// Include a bundled header-only copy of fmtlib or an external one. +// By default spdlog include its own copy. +// + +#if !defined(SPDLOG_FMT_EXTERNAL) + +#ifndef FMT_HEADER_ONLY +#define FMT_HEADER_ONLY +#endif +#ifndef FMT_USE_WINDOWS_H +#define FMT_USE_WINDOWS_H 0 +#endif +#include "../fmt/bundled/format.h" + +#else //external fmtlib + +#include + +#endif + diff --git a/src/dionysus/wasserstein/spdlog/fmt/ostr.h b/src/dionysus/wasserstein/spdlog/fmt/ostr.h new file mode 100755 index 0000000..6959d48 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/fmt/ostr.h @@ -0,0 +1,17 @@ +// +// Copyright(c) 2016 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// include external or bundled copy of fmtlib's ostream support +// +#if !defined(SPDLOG_FMT_EXTERNAL) +#include "../fmt/fmt.h" +#include "../fmt/bundled/ostream.h" +#else +#include +#endif + + diff --git a/src/dionysus/wasserstein/spdlog/formatter.h b/src/dionysus/wasserstein/spdlog/formatter.h new file mode 100755 index 0000000..8bf0f43 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/formatter.h @@ -0,0 +1,47 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "details/log_msg.h" + +#include +#include +#include + +namespace spdlog +{ +namespace details +{ +class flag_formatter; +} + +class formatter +{ +public: + virtual ~formatter() {} + virtual void format(details::log_msg& msg) = 0; +}; + +class pattern_formatter SPDLOG_FINAL : public formatter +{ + +public: + explicit pattern_formatter(const std::string& pattern, pattern_time_type pattern_time = pattern_time_type::local); + pattern_formatter(const pattern_formatter&) = delete; + pattern_formatter& operator=(const pattern_formatter&) = delete; + void format(details::log_msg& msg) override; +private: + const std::string _pattern; + const pattern_time_type _pattern_time; + std::vector> _formatters; + std::tm get_time(details::log_msg& msg); + void handle_flag(char flag); + void compile_pattern(const std::string& pattern); +}; +} + +#include "details/pattern_formatter_impl.h" + diff --git a/src/dionysus/wasserstein/spdlog/logger.h b/src/dionysus/wasserstein/spdlog/logger.h new file mode 100755 index 0000000..eafa9b1 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/logger.h @@ -0,0 +1,132 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +// Thread safe logger (except for set_pattern(..), set_formatter(..) and set_error_handler()) +// Has name, log level, vector of std::shared sink pointers and formatter +// Upon each log write the logger: +// 1. Checks if its log level is enough to log the message +// 2. Format the message using the formatter function +// 3. Pass the formatted message to its sinks to performa the actual logging + +#include "sinks/base_sink.h" +#include "common.h" + +#include +#include +#include + +namespace spdlog +{ + +class logger +{ +public: + logger(const std::string& logger_name, sink_ptr single_sink); + logger(const std::string& name, sinks_init_list); + template + logger(const std::string& name, const It& begin, const It& end); + + virtual ~logger(); + logger(const logger&) = delete; + logger& operator=(const logger&) = delete; + + + template void log(level::level_enum lvl, const char* fmt, const Args&... args); + template void log(level::level_enum lvl, const char* msg); + template void trace(const char* fmt, const Arg1&, const Args&... args); + template void debug(const char* fmt, const Arg1&, const Args&... args); + template void info(const char* fmt, const Arg1&, const Args&... args); + template void warn(const char* fmt, const Arg1&, const Args&... args); + template void error(const char* fmt, const Arg1&, const Args&... args); + template void critical(const char* fmt, const Arg1&, const Args&... args); + + template void log_if(const bool flag, level::level_enum lvl, const char* fmt, const Args&... args); + template void log_if(const bool flag, level::level_enum lvl, const char* msg); + template void trace_if(const bool flag, const char* fmt, const Arg1&, const Args&... args); + template void debug_if(const bool flag, const char* fmt, const Arg1&, const Args&... args); + template void info_if(const bool flag, const char* fmt, const Arg1&, const Args&... args); + template void warn_if(const bool flag, const char* fmt, const Arg1&, const Args&... args); + template void error_if(const bool flag, const char* fmt, const Arg1&, const Args&... args); + template void critical_if(const bool flag, const char* fmt, const Arg1&, const Args&... args); + +#ifdef SPDLOG_WCHAR_TO_UTF8_SUPPORT + template void log(level::level_enum lvl, const wchar_t* msg); + template void log(level::level_enum lvl, const wchar_t* fmt, const Args&... args); + template void trace(const wchar_t* fmt, const Args&... args); + template void debug(const wchar_t* fmt, const Args&... args); + template void info(const wchar_t* fmt, const Args&... args); + template void warn(const wchar_t* fmt, const Args&... args); + template void error(const wchar_t* fmt, const Args&... args); + template void critical(const wchar_t* fmt, const Args&... args); + + template void log_if(const bool flag, level::level_enum lvl, const wchar_t* msg); + template void log_if(const bool flag, level::level_enum lvl, const wchar_t* fmt, const Args&... args); + template void trace_if(const bool flag, const wchar_t* fmt, const Args&... args); + template void debug_if(const bool flag, const wchar_t* fmt, const Args&... args); + template void info_if(const bool flag, const wchar_t* fmt, const Args&... args); + template void warn_if(const bool flag, const wchar_t* fmt, const Args&... args); + template void error_if(const bool flag, const wchar_t* fmt, const Args&... args); + template void critical_if(const bool flag, const wchar_t* fmt, const Args&... args); +#endif // SPDLOG_WCHAR_TO_UTF8_SUPPORT + + template void log(level::level_enum lvl, const T&); + template void trace(const T&); + template void debug(const T&); + template void info(const T&); + template void warn(const T&); + template void error(const T&); + template void critical(const T&); + + template void log_if(const bool flag, level::level_enum lvl, const T&); + template void trace_if(const bool flag, const T&); + template void debug_if(const bool flag, const T&); + template void info_if(const bool flag, const T&); + template void warn_if(const bool flag, const T&); + template void error_if(const bool flag, const T&); + template void critical_if(const bool flag, const T&); + + bool should_log(level::level_enum) const; + void set_level(level::level_enum); + level::level_enum level() const; + const std::string& name() const; + void set_pattern(const std::string&, pattern_time_type = pattern_time_type::local); + void set_formatter(formatter_ptr); + + // automatically call flush() if message level >= log_level + void flush_on(level::level_enum log_level); + + virtual void flush(); + + const std::vector& sinks() const; + + // error handler + virtual void set_error_handler(log_err_handler); + virtual log_err_handler error_handler(); + +protected: + virtual void _sink_it(details::log_msg&); + virtual void _set_pattern(const std::string&, pattern_time_type); + virtual void _set_formatter(formatter_ptr); + + // default error handler: print the error to stderr with the max rate of 1 message/minute + virtual void _default_err_handler(const std::string &msg); + + // return true if the given message level should trigger a flush + bool _should_flush_on(const details::log_msg&); + + const std::string _name; + std::vector _sinks; + formatter_ptr _formatter; + spdlog::level_t _level; + spdlog::level_t _flush_level; + log_err_handler _err_handler; + std::atomic _last_err_time; + std::atomic _msg_counter; +}; +} + +#include "details/logger_impl.h" diff --git a/src/dionysus/wasserstein/spdlog/sinks/android_sink.h b/src/dionysus/wasserstein/spdlog/sinks/android_sink.h new file mode 100755 index 0000000..0ab475e --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/android_sink.h @@ -0,0 +1,90 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#if defined(__ANDROID__) + +#include "../sinks/sink.h" + +#include +#include +#include +#include +#include + +#if !defined(SPDLOG_ANDROID_RETRIES) +#define SPDLOG_ANDROID_RETRIES 2 +#endif + +namespace spdlog +{ +namespace sinks +{ + +/* +* Android sink (logging using __android_log_write) +* __android_log_write is thread-safe. No lock is needed. +*/ +class android_sink : public sink +{ +public: + explicit android_sink(const std::string& tag = "spdlog", bool use_raw_msg = false): _tag(tag), _use_raw_msg(use_raw_msg) {} + + void log(const details::log_msg& msg) override + { + const android_LogPriority priority = convert_to_android(msg.level); + const char *msg_output = (_use_raw_msg ? msg.raw.c_str() : msg.formatted.c_str()); + + // See system/core/liblog/logger_write.c for explanation of return value + int ret = __android_log_write(priority, _tag.c_str(), msg_output); + int retry_count = 0; + while ((ret == -11/*EAGAIN*/) && (retry_count < SPDLOG_ANDROID_RETRIES)) + { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + ret = __android_log_write(priority, _tag.c_str(), msg_output); + retry_count++; + } + + if (ret < 0) + { + throw spdlog_ex("__android_log_write() failed", ret); + } + } + + void flush() override + { + } + +private: + static android_LogPriority convert_to_android(spdlog::level::level_enum level) + { + switch(level) + { + case spdlog::level::trace: + return ANDROID_LOG_VERBOSE; + case spdlog::level::debug: + return ANDROID_LOG_DEBUG; + case spdlog::level::info: + return ANDROID_LOG_INFO; + case spdlog::level::warn: + return ANDROID_LOG_WARN; + case spdlog::level::err: + return ANDROID_LOG_ERROR; + case spdlog::level::critical: + return ANDROID_LOG_FATAL; + default: + return ANDROID_LOG_DEFAULT; + } + } + + std::string _tag; + bool _use_raw_msg; +}; + +} +} + +#endif diff --git a/src/dionysus/wasserstein/spdlog/sinks/ansicolor_sink.h b/src/dionysus/wasserstein/spdlog/sinks/ansicolor_sink.h new file mode 100755 index 0000000..4a393c6 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/ansicolor_sink.h @@ -0,0 +1,133 @@ +// +// Copyright(c) 2017 spdlog authors. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../sinks/base_sink.h" +#include "../common.h" +#include "../details/os.h" + +#include +#include + +namespace spdlog +{ +namespace sinks +{ + +/** + * This sink prefixes the output with an ANSI escape sequence color code depending on the severity + * of the message. + * If no color terminal detected, omit the escape codes. + */ +template +class ansicolor_sink: public base_sink +{ +public: + ansicolor_sink(FILE* file): target_file_(file) + { + should_do_colors_ = details::os::in_terminal(file) && details::os::is_color_terminal(); + colors_[level::trace] = cyan; + colors_[level::debug] = cyan; + colors_[level::info] = reset; + colors_[level::warn] = yellow + bold; + colors_[level::err] = red + bold; + colors_[level::critical] = bold + on_red; + colors_[level::off] = reset; + } + virtual ~ansicolor_sink() + { + _flush(); + } + + void set_color(level::level_enum color_level, const std::string& color) + { + std::lock_guard lock(base_sink::_mutex); + colors_[color_level] = color; + } + + /// Formatting codes + const std::string reset = "\033[00m"; + const std::string bold = "\033[1m"; + const std::string dark = "\033[2m"; + const std::string underline = "\033[4m"; + const std::string blink = "\033[5m"; + const std::string reverse = "\033[7m"; + const std::string concealed = "\033[8m"; + + // Foreground colors + const std::string grey = "\033[30m"; + const std::string red = "\033[31m"; + const std::string green = "\033[32m"; + const std::string yellow = "\033[33m"; + const std::string blue = "\033[34m"; + const std::string magenta = "\033[35m"; + const std::string cyan = "\033[36m"; + const std::string white = "\033[37m"; + + /// Background colors + const std::string on_grey = "\033[40m"; + const std::string on_red = "\033[41m"; + const std::string on_green = "\033[42m"; + const std::string on_yellow = "\033[43m"; + const std::string on_blue = "\033[44m"; + const std::string on_magenta = "\033[45m"; + const std::string on_cyan = "\033[46m"; + const std::string on_white = "\033[47m"; + +protected: + virtual void _sink_it(const details::log_msg& msg) override + { + // Wrap the originally formatted message in color codes. + // If color is not supported in the terminal, log as is instead. + if (should_do_colors_) + { + const std::string& prefix = colors_[msg.level]; + fwrite(prefix.data(), sizeof(char), prefix.size(), target_file_); + fwrite(msg.formatted.data(), sizeof(char), msg.formatted.size(), target_file_); + fwrite(reset.data(), sizeof(char), reset.size(), target_file_); + } + else + { + fwrite(msg.formatted.data(), sizeof(char), msg.formatted.size(), target_file_); + } + _flush(); + } + + void _flush() override + { + fflush(target_file_); + } + FILE* target_file_; + bool should_do_colors_; + std::map colors_; +}; + + +template +class ansicolor_stdout_sink: public ansicolor_sink +{ +public: + ansicolor_stdout_sink(): ansicolor_sink(stdout) + {} +}; + +template +class ansicolor_stderr_sink: public ansicolor_sink +{ +public: + ansicolor_stderr_sink(): ansicolor_sink(stderr) + {} +}; + +typedef ansicolor_stdout_sink ansicolor_stdout_sink_mt; +typedef ansicolor_stdout_sink ansicolor_stdout_sink_st; + +typedef ansicolor_stderr_sink ansicolor_stderr_sink_mt; +typedef ansicolor_stderr_sink ansicolor_stderr_sink_st; + +} // namespace sinks +} // namespace spdlog + diff --git a/src/dionysus/wasserstein/spdlog/sinks/base_sink.h b/src/dionysus/wasserstein/spdlog/sinks/base_sink.h new file mode 100755 index 0000000..d3e4265 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/base_sink.h @@ -0,0 +1,50 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once +// +// base sink templated over a mutex (either dummy or real) +// concrete implementation should only override the _sink_it method. +// all locking is taken care of here so no locking needed by the implementers.. +// + +#include "../sinks/sink.h" +#include "../formatter.h" +#include "../common.h" +#include "../details/log_msg.h" + +#include + +namespace spdlog +{ +namespace sinks +{ +template +class base_sink:public sink +{ +public: + base_sink():_mutex() {} + virtual ~base_sink() = default; + + base_sink(const base_sink&) = delete; + base_sink& operator=(const base_sink&) = delete; + + void log(const details::log_msg& msg) SPDLOG_FINAL override + { + std::lock_guard lock(_mutex); + _sink_it(msg); + } + void flush() SPDLOG_FINAL override + { + _flush(); + } + +protected: + virtual void _sink_it(const details::log_msg& msg) = 0; + virtual void _flush() = 0; + Mutex _mutex; +}; +} +} diff --git a/src/dionysus/wasserstein/spdlog/sinks/dist_sink.h b/src/dionysus/wasserstein/spdlog/sinks/dist_sink.h new file mode 100755 index 0000000..f7b799d --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/dist_sink.h @@ -0,0 +1,73 @@ +// +// Copyright (c) 2015 David Schury, Gabi Melman +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../details/log_msg.h" +#include "../details/null_mutex.h" +#include "../sinks/base_sink.h" +#include "../sinks/sink.h" + +#include +#include +#include +#include + +// Distribution sink (mux). Stores a vector of sinks which get called when log is called + +namespace spdlog +{ +namespace sinks +{ +template +class dist_sink: public base_sink +{ +public: + explicit dist_sink() :_sinks() {} + dist_sink(const dist_sink&) = delete; + dist_sink& operator=(const dist_sink&) = delete; + virtual ~dist_sink() = default; + +protected: + std::vector> _sinks; + + void _sink_it(const details::log_msg& msg) override + { + for (auto &sink : _sinks) + { + if( sink->should_log( msg.level)) + { + sink->log(msg); + } + } + } + + void _flush() override + { + std::lock_guard lock(base_sink::_mutex); + for (auto &sink : _sinks) + sink->flush(); + } + +public: + + + void add_sink(std::shared_ptr sink) + { + std::lock_guard lock(base_sink::_mutex); + _sinks.push_back(sink); + } + + void remove_sink(std::shared_ptr sink) + { + std::lock_guard lock(base_sink::_mutex); + _sinks.erase(std::remove(_sinks.begin(), _sinks.end(), sink), _sinks.end()); + } +}; + +typedef dist_sink dist_sink_mt; +typedef dist_sink dist_sink_st; +} +} diff --git a/src/dionysus/wasserstein/spdlog/sinks/file_sinks.h b/src/dionysus/wasserstein/spdlog/sinks/file_sinks.h new file mode 100755 index 0000000..1b7b3bf --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/file_sinks.h @@ -0,0 +1,242 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../sinks/base_sink.h" +#include "../details/null_mutex.h" +#include "../details/file_helper.h" +#include "../fmt/fmt.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace spdlog +{ +namespace sinks +{ +/* + * Trivial file sink with single file as target + */ +template +class simple_file_sink SPDLOG_FINAL : public base_sink < Mutex > +{ +public: + explicit simple_file_sink(const filename_t &filename, bool truncate = false):_force_flush(false) + { + _file_helper.open(filename, truncate); + } + + void set_force_flush(bool force_flush) + { + _force_flush = force_flush; + } + +protected: + void _sink_it(const details::log_msg& msg) override + { + _file_helper.write(msg); + if(_force_flush) + _file_helper.flush(); + } + void _flush() override + { + _file_helper.flush(); + } +private: + details::file_helper _file_helper; + bool _force_flush; +}; + +typedef simple_file_sink simple_file_sink_mt; +typedef simple_file_sink simple_file_sink_st; + +/* + * Rotating file sink based on size + */ +template +class rotating_file_sink SPDLOG_FINAL : public base_sink < Mutex > +{ +public: + rotating_file_sink(const filename_t &base_filename, + std::size_t max_size, std::size_t max_files) : + _base_filename(base_filename), + _max_size(max_size), + _max_files(max_files), + _current_size(0), + _file_helper() + { + _file_helper.open(calc_filename(_base_filename, 0)); + _current_size = _file_helper.size(); //expensive. called only once + } + + +protected: + void _sink_it(const details::log_msg& msg) override + { + _current_size += msg.formatted.size(); + if (_current_size > _max_size) + { + _rotate(); + _current_size = msg.formatted.size(); + } + _file_helper.write(msg); + } + + void _flush() override + { + _file_helper.flush(); + } + +private: + static filename_t calc_filename(const filename_t& filename, std::size_t index) + { + std::conditional::value, fmt::MemoryWriter, fmt::WMemoryWriter>::type w; + if (index) + w.write(SPDLOG_FILENAME_T("{}.{}"), filename, index); + else + w.write(SPDLOG_FILENAME_T("{}"), filename); + return w.str(); + } + + // Rotate files: + // log.txt -> log.txt.1 + // log.txt.1 -> log.txt.2 + // log.txt.2 -> log.txt.3 + // lo3.txt.3 -> delete + + void _rotate() + { + using details::os::filename_to_str; + _file_helper.close(); + for (auto i = _max_files; i > 0; --i) + { + filename_t src = calc_filename(_base_filename, i - 1); + filename_t target = calc_filename(_base_filename, i); + + if (details::file_helper::file_exists(target)) + { + if (details::os::remove(target) != 0) + { + throw spdlog_ex("rotating_file_sink: failed removing " + filename_to_str(target), errno); + } + } + if (details::file_helper::file_exists(src) && details::os::rename(src, target)) + { + throw spdlog_ex("rotating_file_sink: failed renaming " + filename_to_str(src) + " to " + filename_to_str(target), errno); + } + } + _file_helper.reopen(true); + } + filename_t _base_filename; + std::size_t _max_size; + std::size_t _max_files; + std::size_t _current_size; + details::file_helper _file_helper; +}; + +typedef rotating_file_sink rotating_file_sink_mt; +typedef rotating_file_sinkrotating_file_sink_st; + +/* + * Default generator of daily log file names. + */ +struct default_daily_file_name_calculator +{ + // Create filename for the form basename.YYYY-MM-DD_hh-mm + static filename_t calc_filename(const filename_t& basename) + { + std::tm tm = spdlog::details::os::localtime(); + std::conditional::value, fmt::MemoryWriter, fmt::WMemoryWriter>::type w; + w.write(SPDLOG_FILENAME_T("{}_{:04d}-{:02d}-{:02d}_{:02d}-{:02d}"), basename, tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday, tm.tm_hour, tm.tm_min); + return w.str(); + } +}; + +/* + * Generator of daily log file names in format basename.YYYY-MM-DD + */ +struct dateonly_daily_file_name_calculator +{ + // Create filename for the form basename.YYYY-MM-DD + static filename_t calc_filename(const filename_t& basename) + { + std::tm tm = spdlog::details::os::localtime(); + std::conditional::value, fmt::MemoryWriter, fmt::WMemoryWriter>::type w; + w.write(SPDLOG_FILENAME_T("{}_{:04d}-{:02d}-{:02d}"), basename, tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday); + return w.str(); + } +}; + +/* + * Rotating file sink based on date. rotates at midnight + */ +template +class daily_file_sink SPDLOG_FINAL :public base_sink < Mutex > +{ +public: + //create daily file sink which rotates on given time + daily_file_sink( + const filename_t& base_filename, + int rotation_hour, + int rotation_minute) : _base_filename(base_filename), + _rotation_h(rotation_hour), + _rotation_m(rotation_minute) + { + if (rotation_hour < 0 || rotation_hour > 23 || rotation_minute < 0 || rotation_minute > 59) + throw spdlog_ex("daily_file_sink: Invalid rotation time in ctor"); + _rotation_tp = _next_rotation_tp(); + _file_helper.open(FileNameCalc::calc_filename(_base_filename)); + } + + +protected: + void _sink_it(const details::log_msg& msg) override + { + if (std::chrono::system_clock::now() >= _rotation_tp) + { + _file_helper.open(FileNameCalc::calc_filename(_base_filename)); + _rotation_tp = _next_rotation_tp(); + } + _file_helper.write(msg); + } + + void _flush() override + { + _file_helper.flush(); + } + +private: + std::chrono::system_clock::time_point _next_rotation_tp() + { + auto now = std::chrono::system_clock::now(); + time_t tnow = std::chrono::system_clock::to_time_t(now); + tm date = spdlog::details::os::localtime(tnow); + date.tm_hour = _rotation_h; + date.tm_min = _rotation_m; + date.tm_sec = 0; + auto rotation_time = std::chrono::system_clock::from_time_t(std::mktime(&date)); + if (rotation_time > now) + return rotation_time; + else + return std::chrono::system_clock::time_point(rotation_time + std::chrono::hours(24)); + } + + filename_t _base_filename; + int _rotation_h; + int _rotation_m; + std::chrono::system_clock::time_point _rotation_tp; + details::file_helper _file_helper; +}; + +typedef daily_file_sink daily_file_sink_mt; +typedef daily_file_sink daily_file_sink_st; +} +} diff --git a/src/dionysus/wasserstein/spdlog/sinks/msvc_sink.h b/src/dionysus/wasserstein/spdlog/sinks/msvc_sink.h new file mode 100755 index 0000000..249ef71 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/msvc_sink.h @@ -0,0 +1,51 @@ +// +// Copyright(c) 2016 Alexander Dalshov. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#if defined(_MSC_VER) + +#include "../sinks/base_sink.h" +#include "../details/null_mutex.h" + +#include + +#include +#include + +namespace spdlog +{ +namespace sinks +{ +/* +* MSVC sink (logging using OutputDebugStringA) +*/ +template +class msvc_sink : public base_sink < Mutex > +{ +public: + explicit msvc_sink() + { + } + + + +protected: + void _sink_it(const details::log_msg& msg) override + { + OutputDebugStringA(msg.formatted.c_str()); + } + + void _flush() override + {} +}; + +typedef msvc_sink msvc_sink_mt; +typedef msvc_sink msvc_sink_st; + +} +} + +#endif diff --git a/src/dionysus/wasserstein/spdlog/sinks/null_sink.h b/src/dionysus/wasserstein/spdlog/sinks/null_sink.h new file mode 100755 index 0000000..582f8ac --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/null_sink.h @@ -0,0 +1,34 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../sinks/base_sink.h" +#include "../details/null_mutex.h" + +#include + +namespace spdlog +{ +namespace sinks +{ + +template +class null_sink : public base_sink < Mutex > +{ +protected: + void _sink_it(const details::log_msg&) override + {} + + void _flush() override + {} + +}; +typedef null_sink null_sink_st; +typedef null_sink null_sink_mt; + +} +} + diff --git a/src/dionysus/wasserstein/spdlog/sinks/ostream_sink.h b/src/dionysus/wasserstein/spdlog/sinks/ostream_sink.h new file mode 100755 index 0000000..aa9cd54 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/ostream_sink.h @@ -0,0 +1,47 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../details/null_mutex.h" +#include "../sinks/base_sink.h" + +#include +#include + +namespace spdlog +{ +namespace sinks +{ +template +class ostream_sink: public base_sink +{ +public: + explicit ostream_sink(std::ostream& os, bool force_flush=false) :_ostream(os), _force_flush(force_flush) {} + ostream_sink(const ostream_sink&) = delete; + ostream_sink& operator=(const ostream_sink&) = delete; + virtual ~ostream_sink() = default; + +protected: + void _sink_it(const details::log_msg& msg) override + { + _ostream.write(msg.formatted.data(), msg.formatted.size()); + if (_force_flush) + _ostream.flush(); + } + + void _flush() override + { + _ostream.flush(); + } + + std::ostream& _ostream; + bool _force_flush; +}; + +typedef ostream_sink ostream_sink_mt; +typedef ostream_sink ostream_sink_st; +} +} diff --git a/src/dionysus/wasserstein/spdlog/sinks/sink.h b/src/dionysus/wasserstein/spdlog/sinks/sink.h new file mode 100755 index 0000000..af61b54 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/sink.h @@ -0,0 +1,53 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + + +#pragma once + +#include "../details/log_msg.h" + +namespace spdlog +{ +namespace sinks +{ +class sink +{ +public: + sink() + { + _level = level::trace; + } + + virtual ~sink() {} + virtual void log(const details::log_msg& msg) = 0; + virtual void flush() = 0; + + bool should_log(level::level_enum msg_level) const; + void set_level(level::level_enum log_level); + level::level_enum level() const; + +private: + level_t _level; + +}; + +inline bool sink::should_log(level::level_enum msg_level) const +{ + return msg_level >= _level.load(std::memory_order_relaxed); +} + +inline void sink::set_level(level::level_enum log_level) +{ + _level.store(log_level); +} + +inline level::level_enum sink::level() const +{ + return static_cast(_level.load(std::memory_order_relaxed)); +} + +} +} + diff --git a/src/dionysus/wasserstein/spdlog/sinks/stdout_sinks.h b/src/dionysus/wasserstein/spdlog/sinks/stdout_sinks.h new file mode 100755 index 0000000..ab59412 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/stdout_sinks.h @@ -0,0 +1,77 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../details/null_mutex.h" +#include "../sinks/base_sink.h" + +#include +#include +#include + +namespace spdlog +{ +namespace sinks +{ + +template +class stdout_sink SPDLOG_FINAL : public base_sink +{ + using MyType = stdout_sink; +public: + stdout_sink() + {} + static std::shared_ptr instance() + { + static std::shared_ptr instance = std::make_shared(); + return instance; + } +protected: + void _sink_it(const details::log_msg& msg) override + { + fwrite(msg.formatted.data(), sizeof(char), msg.formatted.size(), stdout); + _flush(); + } + + void _flush() override + { + fflush(stdout); + } +}; + +typedef stdout_sink stdout_sink_st; +typedef stdout_sink stdout_sink_mt; + + +template +class stderr_sink SPDLOG_FINAL : public base_sink +{ + using MyType = stderr_sink; +public: + stderr_sink() + {} + static std::shared_ptr instance() + { + static std::shared_ptr instance = std::make_shared(); + return instance; + } +protected: + void _sink_it(const details::log_msg& msg) override + { + fwrite(msg.formatted.data(), sizeof(char), msg.formatted.size(), stderr); + _flush(); + } + + void _flush() override + { + fflush(stderr); + } +}; + +typedef stderr_sink stderr_sink_mt; +typedef stderr_sink stderr_sink_st; +} +} diff --git a/src/dionysus/wasserstein/spdlog/sinks/syslog_sink.h b/src/dionysus/wasserstein/spdlog/sinks/syslog_sink.h new file mode 100755 index 0000000..2cbc7af --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/syslog_sink.h @@ -0,0 +1,81 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../common.h" + +#ifdef SPDLOG_ENABLE_SYSLOG + +#include "../sinks/sink.h" +#include "../details/log_msg.h" + +#include +#include +#include + + +namespace spdlog +{ +namespace sinks +{ +/** + * Sink that write to syslog using the `syscall()` library call. + * + * Locking is not needed, as `syslog()` itself is thread-safe. + */ +class syslog_sink : public sink +{ +public: + // + syslog_sink(const std::string& ident = "", int syslog_option=0, int syslog_facility=LOG_USER): + _ident(ident) + { + _priorities[static_cast(level::trace)] = LOG_DEBUG; + _priorities[static_cast(level::debug)] = LOG_DEBUG; + _priorities[static_cast(level::info)] = LOG_INFO; + _priorities[static_cast(level::warn)] = LOG_WARNING; + _priorities[static_cast(level::err)] = LOG_ERR; + _priorities[static_cast(level::critical)] = LOG_CRIT; + _priorities[static_cast(level::off)] = LOG_INFO; + + //set ident to be program name if empty + ::openlog(_ident.empty()? nullptr:_ident.c_str(), syslog_option, syslog_facility); + } + ~syslog_sink() + { + ::closelog(); + } + + syslog_sink(const syslog_sink&) = delete; + syslog_sink& operator=(const syslog_sink&) = delete; + + void log(const details::log_msg &msg) override + { + ::syslog(syslog_prio_from_level(msg), "%s", msg.raw.str().c_str()); + } + + void flush() override + { + } + + +private: + std::array _priorities; + //must store the ident because the man says openlog might use the pointer as is and not a string copy + const std::string _ident; + + // + // Simply maps spdlog's log level to syslog priority level. + // + int syslog_prio_from_level(const details::log_msg &msg) const + { + return _priorities[static_cast(msg.level)]; + } +}; +} +} + +#endif diff --git a/src/dionysus/wasserstein/spdlog/sinks/wincolor_sink.h b/src/dionysus/wasserstein/spdlog/sinks/wincolor_sink.h new file mode 100755 index 0000000..544aeb1 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/sinks/wincolor_sink.h @@ -0,0 +1,117 @@ +// +// Copyright(c) 2016 spdlog +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +#include "../sinks/base_sink.h" +#include "../details/null_mutex.h" +#include "../common.h" + +#include +#include +#include +#include + +namespace spdlog +{ +namespace sinks +{ +/* + * Windows color console sink. Uses WriteConsoleA to write to the console with colors + */ +template +class wincolor_sink: public base_sink +{ +public: + const WORD BOLD = FOREGROUND_INTENSITY; + const WORD RED = FOREGROUND_RED; + const WORD CYAN = FOREGROUND_GREEN | FOREGROUND_BLUE; + const WORD WHITE = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE; + const WORD YELLOW = FOREGROUND_RED | FOREGROUND_GREEN; + + wincolor_sink(HANDLE std_handle): out_handle_(std_handle) + { + colors_[level::trace] = CYAN; + colors_[level::debug] = CYAN; + colors_[level::info] = WHITE | BOLD; + colors_[level::warn] = YELLOW | BOLD; + colors_[level::err] = RED | BOLD; // red bold + colors_[level::critical] = BACKGROUND_RED | WHITE | BOLD; // white bold on red background + colors_[level::off] = 0; + } + + virtual ~wincolor_sink() + { + this->flush(); + } + + wincolor_sink(const wincolor_sink& other) = delete; + wincolor_sink& operator=(const wincolor_sink& other) = delete; + +protected: + virtual void _sink_it(const details::log_msg& msg) override + { + auto color = colors_[msg.level]; + auto orig_attribs = set_console_attribs(color); + WriteConsoleA(out_handle_, msg.formatted.data(), static_cast(msg.formatted.size()), nullptr, nullptr); + SetConsoleTextAttribute(out_handle_, orig_attribs); //reset to orig colors + } + + virtual void _flush() override + { + // windows console always flushed? + } + + // change the color for the given level + void set_color(level::level_enum level, WORD color) + { + std::lock_guard lock(base_sink::_mutex); + colors_[level] = color; + } + +private: + HANDLE out_handle_; + std::map colors_; + + // set color and return the orig console attributes (for resetting later) + WORD set_console_attribs(WORD attribs) + { + CONSOLE_SCREEN_BUFFER_INFO orig_buffer_info; + GetConsoleScreenBufferInfo(out_handle_, &orig_buffer_info); + SetConsoleTextAttribute(out_handle_, attribs); + return orig_buffer_info.wAttributes; //return orig attribs + } +}; + +// +// windows color console to stdout +// +template +class wincolor_stdout_sink: public wincolor_sink +{ +public: + wincolor_stdout_sink() : wincolor_sink(GetStdHandle(STD_OUTPUT_HANDLE)) + {} +}; + +typedef wincolor_stdout_sink wincolor_stdout_sink_mt; +typedef wincolor_stdout_sink wincolor_stdout_sink_st; + +// +// windows color console to stderr +// +template +class wincolor_stderr_sink: public wincolor_sink +{ +public: + wincolor_stderr_sink() : wincolor_sink(GetStdHandle(STD_ERROR_HANDLE)) + {} +}; + +typedef wincolor_stderr_sink wincolor_stderr_sink_mt; +typedef wincolor_stderr_sink wincolor_stderr_sink_st; + +} +} diff --git a/src/dionysus/wasserstein/spdlog/spdlog.h b/src/dionysus/wasserstein/spdlog/spdlog.h new file mode 100755 index 0000000..f5b9ca6 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/spdlog.h @@ -0,0 +1,187 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// +// spdlog main header file. +// see example.cpp for usage example + +#pragma once + +#define SPDLOG_VERSION "0.13.0" + +#include "tweakme.h" +#include "common.h" +#include "logger.h" + +#include +#include +#include +#include + +namespace spdlog +{ + +// +// Return an existing logger or nullptr if a logger with such name doesn't exist. +// example: spdlog::get("my_logger")->info("hello {}", "world"); +// +std::shared_ptr get(const std::string& name); + + +// +// Set global formatting +// example: spdlog::set_pattern("%Y-%m-%d %H:%M:%S.%e %l : %v"); +// +void set_pattern(const std::string& format_string); +void set_formatter(formatter_ptr f); + +// +// Set global logging level for +// +void set_level(level::level_enum log_level); + +// +// Set global error handler +// +void set_error_handler(log_err_handler); + +// +// Turn on async mode (off by default) and set the queue size for each async_logger. +// effective only for loggers created after this call. +// queue_size: size of queue (must be power of 2): +// Each logger will pre-allocate a dedicated queue with queue_size entries upon construction. +// +// async_overflow_policy (optional, block_retry by default): +// async_overflow_policy::block_retry - if queue is full, block until queue has room for the new log entry. +// async_overflow_policy::discard_log_msg - never block and discard any new messages when queue overflows. +// +// worker_warmup_cb (optional): +// callback function that will be called in worker thread upon start (can be used to init stuff like thread affinity) +// +// worker_teardown_cb (optional): +// callback function that will be called in worker thread upon exit +// +void set_async_mode(size_t queue_size, const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, const std::function& worker_warmup_cb = nullptr, const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), const std::function& worker_teardown_cb = nullptr); + +// Turn off async mode +void set_sync_mode(); + + +// +// Create and register multi/single threaded basic file logger. +// Basic logger simply writes to given file without any limitatons or rotations. +// +std::shared_ptr basic_logger_mt(const std::string& logger_name, const filename_t& filename, bool truncate = false); +std::shared_ptr basic_logger_st(const std::string& logger_name, const filename_t& filename, bool truncate = false); + +// +// Create and register multi/single threaded rotating file logger +// +std::shared_ptr rotating_logger_mt(const std::string& logger_name, const filename_t& filename, size_t max_file_size, size_t max_files); +std::shared_ptr rotating_logger_st(const std::string& logger_name, const filename_t& filename, size_t max_file_size, size_t max_files); + +// +// Create file logger which creates new file on the given time (default in midnight): +// +std::shared_ptr daily_logger_mt(const std::string& logger_name, const filename_t& filename, int hour=0, int minute=0); +std::shared_ptr daily_logger_st(const std::string& logger_name, const filename_t& filename, int hour=0, int minute=0); + +// +// Create and register stdout/stderr loggers +// +std::shared_ptr stdout_logger_mt(const std::string& logger_name); +std::shared_ptr stdout_logger_st(const std::string& logger_name); +std::shared_ptr stderr_logger_mt(const std::string& logger_name); +std::shared_ptr stderr_logger_st(const std::string& logger_name); +// +// Create and register colored stdout/stderr loggers +// +std::shared_ptr stdout_color_mt(const std::string& logger_name); +std::shared_ptr stdout_color_st(const std::string& logger_name); +std::shared_ptr stderr_color_mt(const std::string& logger_name); +std::shared_ptr stderr_color_st(const std::string& logger_name); + + +// +// Create and register a syslog logger +// +#ifdef SPDLOG_ENABLE_SYSLOG +std::shared_ptr syslog_logger(const std::string& logger_name, const std::string& ident = "", int syslog_option = 0); +#endif + +#if defined(__ANDROID__) +std::shared_ptr android_logger(const std::string& logger_name, const std::string& tag = "spdlog"); +#endif + +// Create and register a logger with a single sink +std::shared_ptr create(const std::string& logger_name, const sink_ptr& sink); + +// Create and register a logger with multiple sinks +std::shared_ptr create(const std::string& logger_name, sinks_init_list sinks); +template +std::shared_ptr create(const std::string& logger_name, const It& sinks_begin, const It& sinks_end); + + +// Create and register a logger with templated sink type +// Example: +// spdlog::create("mylog", "dailylog_filename"); +template +std::shared_ptr create(const std::string& logger_name, Args...); + +// Create and register an async logger with a single sink +std::shared_ptr create_async(const std::string& logger_name, const sink_ptr& sink, size_t queue_size, const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, const std::function& worker_warmup_cb = nullptr, const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), const std::function& worker_teardown_cb = nullptr); + +// Create and register an async logger with multiple sinks +std::shared_ptr create_async(const std::string& logger_name, sinks_init_list sinks, size_t queue_size, const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, const std::function& worker_warmup_cb = nullptr, const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), const std::function& worker_teardown_cb = nullptr); +template +std::shared_ptr create_async(const std::string& logger_name, const It& sinks_begin, const It& sinks_end, size_t queue_size, const async_overflow_policy overflow_policy = async_overflow_policy::block_retry, const std::function& worker_warmup_cb = nullptr, const std::chrono::milliseconds& flush_interval_ms = std::chrono::milliseconds::zero(), const std::function& worker_teardown_cb = nullptr); + +// Register the given logger with the given name +void register_logger(std::shared_ptr logger); + +// Apply a user defined function on all registered loggers +// Example: +// spdlog::apply_all([&](std::shared_ptr l) {l->flush();}); +void apply_all(std::function)> fun); + +// Drop the reference to the given logger +void drop(const std::string &name); + +// Drop all references from the registry +void drop_all(); + + +/////////////////////////////////////////////////////////////////////////////// +// +// Trace & Debug can be switched on/off at compile time for zero cost debug statements. +// Uncomment SPDLOG_DEBUG_ON/SPDLOG_TRACE_ON in teakme.h to enable. +// SPDLOG_TRACE(..) will also print current file and line. +// +// Example: +// spdlog::set_level(spdlog::level::trace); +// SPDLOG_TRACE(my_logger, "some trace message"); +// SPDLOG_TRACE(my_logger, "another trace message {} {}", 1, 2); +// SPDLOG_DEBUG(my_logger, "some debug message {} {}", 3, 4); +// SPDLOG_DEBUG_IF(my_logger, true, "some debug message {} {}", 3, 4); +/////////////////////////////////////////////////////////////////////////////// + +#ifdef SPDLOG_TRACE_ON +#define SPDLOG_STR_H(x) #x +#define SPDLOG_STR_HELPER(x) SPDLOG_STR_H(x) +#define SPDLOG_TRACE(logger, ...) logger->trace("[" __FILE__ " line #" SPDLOG_STR_HELPER(__LINE__) "] " __VA_ARGS__) +#define SPDLOG_TRACE_IF(logger, flag, ...) logger->trace_if(flag, "[" __FILE__ " line #" SPDLOG_STR_HELPER(__LINE__) "] " __VA_ARGS__) +#else +#define SPDLOG_TRACE(logger, ...) +#endif + +#ifdef SPDLOG_DEBUG_ON +#define SPDLOG_DEBUG(logger, ...) logger->debug(__VA_ARGS__) +#define SPDLOG_DEBUG_IF(logger, flag, ...) logger->debug_if(flag, __VA_ARGS__) +#else +#define SPDLOG_DEBUG(logger, ...) +#endif + +} + + +#include "details/spdlog_impl.h" diff --git a/src/dionysus/wasserstein/spdlog/tweakme.h b/src/dionysus/wasserstein/spdlog/tweakme.h new file mode 100755 index 0000000..53f5cf7 --- /dev/null +++ b/src/dionysus/wasserstein/spdlog/tweakme.h @@ -0,0 +1,141 @@ +// +// Copyright(c) 2015 Gabi Melman. +// Distributed under the MIT License (http://opensource.org/licenses/MIT) +// + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// +// Edit this file to squeeze more performance, and to customize supported features +// +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Under Linux, the much faster CLOCK_REALTIME_COARSE clock can be used. +// This clock is less accurate - can be off by dozens of millis - depending on the kernel HZ. +// Uncomment to use it instead of the regular clock. +// +// #define SPDLOG_CLOCK_COARSE +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment if date/time logging is not needed and never appear in the log pattern. +// This will prevent spdlog from quering the clock on each log call. +// +// WARNING: If the log pattern contains any date/time while this flag is on, the result is undefined. +// You must set new pattern(spdlog::set_pattern(..") without any date/time in it +// +// #define SPDLOG_NO_DATETIME +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment if thread id logging is not needed (i.e. no %t in the log pattern). +// This will prevent spdlog from quering the thread id on each log call. +// +// WARNING: If the log pattern contains thread id (i.e, %t) while this flag is on, the result is undefined. +// +// #define SPDLOG_NO_THREAD_ID +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment if logger name logging is not needed. +// This will prevent spdlog from copying the logger name on each log call. +// +// #define SPDLOG_NO_NAME +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to enable the SPDLOG_DEBUG/SPDLOG_TRACE macros. +// +// #define SPDLOG_DEBUG_ON +// #define SPDLOG_TRACE_ON +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to avoid locking in the registry operations (spdlog::get(), spdlog::drop() spdlog::register()). +// Use only if your code never modifes concurrently the registry. +// Note that upon creating a logger the registry is modified by spdlog.. +// +// #define SPDLOG_NO_REGISTRY_MUTEX +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to avoid spdlog's usage of atomic log levels +// Use only if your code never modifies a logger's log levels concurrently by different threads. +// +// #define SPDLOG_NO_ATOMIC_LEVELS +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to enable usage of wchar_t for file names on Windows. +// +// #define SPDLOG_WCHAR_FILENAMES +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to override default eol ("\n" or "\r\n" under Linux/Windows) +// +// #define SPDLOG_EOL ";-)\n" +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to use your own copy of the fmt library instead of spdlog's copy. +// In this case spdlog will try to include so set your -I flag accordingly. +// +// #define SPDLOG_FMT_EXTERNAL +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to enable syslog (disabled by default) +// +// #define SPDLOG_ENABLE_SYSLOG +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to enable wchar_t support (convert to utf8) +// +// #define SPDLOG_WCHAR_TO_UTF8_SUPPORT +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to prevent child processes from inheriting log file descriptors +// +// #define SPDLOG_PREVENT_CHILD_FD +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to mark some types as final, allowing more optimizations in release +// mode with recent compilers. See GCC's documentation for -Wsuggest-final-types +// for instance. +// +// #define SPDLOG_FINAL final +/////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to enable message counting feature. Adds %i logger pattern that +// prints log message sequence id. +// +// #define SPDLOG_ENABLE_MESSAGE_COUNTER +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +// Uncomment to enable user defined tag names +// +// #define SPDLOG_LEVEL_NAMES { " TRACE", " DEBUG", " INFO", +// " WARNING", " ERROR", "CRITICAL", "OFF" }; +/////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/src/dionysus/wasserstein/wasserstein.h b/src/dionysus/wasserstein/wasserstein.h new file mode 100755 index 0000000..b90a545 --- /dev/null +++ b/src/dionysus/wasserstein/wasserstein.h @@ -0,0 +1,347 @@ +/* + +Copyright (c) 2015, M. Kerber, D. Morozov, A. Nigmetov +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +(Enhancements) to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to copyright holder, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. + + */ + +#ifndef HERA_WASSERSTEIN_H +#define HERA_WASSERSTEIN_H + +#include +#include +#include + +#include "def_debug_ws.h" +#include "basic_defs_ws.h" +#include "diagram_reader.h" +#include "auction_runner_gs.h" +#include "auction_runner_gs_single_diag.h" +#include "auction_runner_jac.h" +#include "auction_runner_fr.h" + + +namespace hera +{ + +template().begin())>::type > +struct DiagramTraits +{ + using Container = PairContainer_; + using PointType = PointType_; + using RealType = typename std::remove_reference< decltype(std::declval()[0]) >::type; + + static RealType get_x(const PointType& p) { return p[0]; } + static RealType get_y(const PointType& p) { return p[1]; } +}; + +template +struct DiagramTraits> +{ + using PointType = std::pair; + using RealType = double; + using Container = std::vector; + + static RealType get_x(const PointType& p) { return p.first; } + static RealType get_y(const PointType& p) { return p.second; } +}; + + +namespace ws +{ + + // compare as multisets + template + bool are_equal(const PairContainer& dgm1, const PairContainer& dgm2) + { + if (dgm1.size() != dgm2.size()) { + return false; + } + + using Traits = typename hera::DiagramTraits; + using PointType = typename Traits::PointType; + + std::map m1, m2; + + for(const auto& pair1 : dgm1) { + m1[pair1]++; + } + + for(const auto& pair2 : dgm2) { + m2[pair2]++; + } + + return m1 == m2; + } + + // to handle points with one coordinate = infinity + template + RealType get_one_dimensional_cost(std::vector& set_A, + std::vector& set_B, + const RealType wasserstein_power) + { + if (set_A.size() != set_B.size()) { + return std::numeric_limits::infinity(); + } + std::sort(set_A.begin(), set_A.end()); + std::sort(set_B.begin(), set_B.end()); + RealType result = 0.0; + for(size_t i = 0; i < set_A.size(); ++i) { + result += std::pow(std::fabs(set_A[i] - set_B[i]), wasserstein_power); + } + return result; + } + + + template + struct SplitProblemInput + { + std::vector> A_1; + std::vector> B_1; + std::vector> A_2; + std::vector> B_2; + + std::unordered_map A_1_indices; + std::unordered_map A_2_indices; + std::unordered_map B_1_indices; + std::unordered_map B_2_indices; + + RealType mid_coord { 0.0 }; + RealType strip_width { 0.0 }; + + void init_vectors(size_t n) + { + + A_1_indices.clear(); + A_2_indices.clear(); + B_1_indices.clear(); + B_2_indices.clear(); + + A_1.clear(); + A_2.clear(); + B_1.clear(); + B_2.clear(); + + A_1.reserve(n / 2); + B_1.reserve(n / 2); + A_2.reserve(n / 2); + B_2.reserve(n / 2); + } + + void init(const std::vector>& A, + const std::vector>& B) + { + using DiagramPointR = DiagramPoint; + + init_vectors(A.size()); + + RealType min_sum = std::numeric_limits::max(); + RealType max_sum = -std::numeric_limits::max(); + for(const auto& p_A : A) { + RealType s = p_A[0] + p_A[1]; + if (s > max_sum) + max_sum = s; + if (s < min_sum) + min_sum = s; + mid_coord += s; + } + + mid_coord /= A.size(); + + strip_width = 0.25 * (max_sum - min_sum); + + auto first_diag_iter = std::upper_bound(A.begin(), A.end(), 0, [](const int& a, const DiagramPointR& p) { return a < (int)(p.is_diagonal()); }); + size_t num_normal_A_points = std::distance(A.begin(), first_diag_iter); + + // process all normal points in A, + // projections follow normal points + for(size_t i = 0; i < A.size(); ++i) { + + assert(i < num_normal_A_points and A.is_normal() or i >= num_normal_A_points and A.is_diagonal()); + assert(i < num_normal_A_points and B.is_diagonal() or i >= num_normal_A_points and B.is_normal()); + + RealType s = i < num_normal_A_points ? A[i][0] + A[i][1] : B[i][0] + B[i][1]; + + if (s < mid_coord + strip_width) { + // add normal point and its projection to the + // left half + A_1.push_back(A[i]); + B_1.push_back(B[i]); + A_1_indices[i] = A_1.size() - 1; + B_1_indices[i] = B_1.size() - 1; + } + + if (s > mid_coord - strip_width) { + // to the right half + A_2.push_back(A[i]); + B_2.push_back(B[i]); + A_2_indices[i] = A_2.size() - 1; + B_2_indices[i] = B_2.size() - 1; + } + + } + } // end init + + }; + + + // CAUTION: + // this function assumes that all coordinates are finite + // points at infinity are processed in wasserstein_cost + template + RealType wasserstein_cost_vec(const std::vector>& A, + const std::vector>& B, + const AuctionParams& params, + const std::string& _log_filename_prefix) + { + if (params.wasserstein_power < 1.0) { + throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power)); + } + if (params.delta < 0.0) { + throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(params.delta)); + } + if (params.initial_epsilon < 0.0) { + throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(params.initial_epsilon)); + } + if (params.epsilon_common_ratio < 0.0) { + throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio)); + } + + if (A.empty() and B.empty()) + return 0.0; + + RealType result; + + // just use Gauss-Seidel + AuctionRunnerGS auction(A, B, params, _log_filename_prefix); + auction.run_auction(); + result = auction.get_wasserstein_cost(); + return result; + } + +} // ws + + + +template +typename DiagramTraits::RealType +wasserstein_cost(const PairContainer& A, + const PairContainer& B, + const AuctionParams< typename DiagramTraits::RealType >& params, + const std::string& _log_filename_prefix = "") +{ + using Traits = DiagramTraits; + + //using PointType = typename Traits::PointType; + using RealType = typename Traits::RealType; + + if (hera::ws::are_equal(A, B)) { + return 0.0; + } + + bool a_empty = true; + bool b_empty = true; + RealType total_cost_A = 0.0; + RealType total_cost_B = 0.0; + + using DgmPoint = hera::ws::DiagramPoint; + + std::vector dgm_A, dgm_B; + // coordinates of points at infinity + std::vector x_plus_A, x_minus_A, y_plus_A, y_minus_A; + std::vector x_plus_B, x_minus_B, y_plus_B, y_minus_B; + // loop over A, add projections of A-points to corresponding positions + // in B-vector + for(auto& pair_A : A) { + a_empty = false; + RealType x = Traits::get_x(pair_A); + RealType y = Traits::get_y(pair_A); + if ( x == std::numeric_limits::infinity()) { + y_plus_A.push_back(y); + } else if (x == -std::numeric_limits::infinity()) { + y_minus_A.push_back(y); + } else if (y == std::numeric_limits::infinity()) { + x_plus_A.push_back(x); + } else if (y == -std::numeric_limits::infinity()) { + x_minus_A.push_back(x); + } else { + dgm_A.emplace_back(x, y, DgmPoint::NORMAL); + dgm_B.emplace_back(x, y, DgmPoint::DIAG); + total_cost_A += std::pow(dgm_A.back().persistence_lp(params.internal_p), params.wasserstein_power); + } + } + // the same for B + for(auto& pair_B : B) { + b_empty = false; + RealType x = Traits::get_x(pair_B); + RealType y = Traits::get_y(pair_B); + if (x == std::numeric_limits::infinity()) { + y_plus_B.push_back(y); + } else if (x == -std::numeric_limits::infinity()) { + y_minus_B.push_back(y); + } else if (y == std::numeric_limits::infinity()) { + x_plus_B.push_back(x); + } else if (y == -std::numeric_limits::infinity()) { + x_minus_B.push_back(x); + } else { + dgm_A.emplace_back(x, y, DgmPoint::DIAG); + dgm_B.emplace_back(x, y, DgmPoint::NORMAL); + total_cost_B += std::pow(dgm_B.back().persistence_lp(params.internal_p), params.wasserstein_power); + } + } + + RealType infinity_cost = ws::get_one_dimensional_cost(x_plus_A, x_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(x_minus_A, x_minus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_plus_A, y_plus_B, params.wasserstein_power); + infinity_cost += ws::get_one_dimensional_cost(y_minus_A, y_minus_B, params.wasserstein_power); + + if (a_empty) + return total_cost_B + infinity_cost; + + if (b_empty) + return total_cost_A + infinity_cost; + + + if (infinity_cost == std::numeric_limits::infinity()) { + return infinity_cost; + } else { + return infinity_cost + wasserstein_cost_vec(dgm_A, dgm_B, params, _log_filename_prefix); + } + +} + +template +typename DiagramTraits::RealType +wasserstein_dist(PairContainer& A, + PairContainer& B, + const AuctionParams::RealType> params, + const std::string& _log_filename_prefix = "") +{ + using Real = typename DiagramTraits::RealType; + return std::pow(hera::wasserstein_cost(A, B, params, _log_filename_prefix), Real(1.)/params.wasserstein_power); +} + +} // end of namespace hera + +#endif diff --git a/src/dionysus/wasserstein/wasserstein_pure_geom.hpp b/src/dionysus/wasserstein/wasserstein_pure_geom.hpp new file mode 100755 index 0000000..2a57599 --- /dev/null +++ b/src/dionysus/wasserstein/wasserstein_pure_geom.hpp @@ -0,0 +1,87 @@ +#ifndef WASSERSTEIN_PURE_GEOM_HPP +#define WASSERSTEIN_PURE_GEOM_HPP + +#define WASSERSTEIN_PURE_GEOM + + +#include "diagram_reader.h" +#include "auction_oracle_kdtree_pure_geom.h" +#include "auction_runner_gs.h" +#include "auction_runner_jac.h" + +namespace hera +{ +namespace ws +{ + + template + using DynamicTraits = typename hera::ws::dnn::DynamicPointTraits; + + template + using DynamicPoint = typename hera::ws::dnn::DynamicPointTraits::PointType; + + template + using DynamicPointVector = typename hera::ws::dnn::DynamicPointVector; + + template + using AuctionRunnerGSR = typename hera::ws::AuctionRunnerGS, hera::ws::dnn::DynamicPointVector>; + + template + using AuctionRunnerJacR = typename hera::ws::AuctionRunnerJac, hera::ws::dnn::DynamicPointVector>; + + +double wasserstein_cost(const DynamicPointVector& set_A, const DynamicPointVector& set_B, const AuctionParams& params) +{ + if (params.wasserstein_power < 1.0) { + throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power)); + } + + if (params.delta < 0.0) { + throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(params.delta)); + } + + if (params.initial_epsilon < 0.0) { + throw std::runtime_error("Bad initial epsilon in Wasserstein" + std::to_string(params.initial_epsilon)); + } + + if (params.epsilon_common_ratio < 0.0) { + throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio)); + } + + if (set_A.size() != set_B.size()) { + throw std::runtime_error("Different cardinalities of point clouds: " + std::to_string(set_A.size()) + " != " + std::to_string(set_B.size())); + } + + DynamicTraits traits(params.dim); + + DynamicPointVector set_A_copy(set_A); + DynamicPointVector set_B_copy(set_B); + + // set point id to the index in vector + for(size_t i = 0; i < set_A.size(); ++i) { + traits.id(set_A_copy[i]) = i; + traits.id(set_B_copy[i]) = i; + } + + if (params.max_bids_per_round == 1) { + hera::ws::AuctionRunnerGSR auction(set_A_copy, set_B_copy, params); + auction.run_auction(); + return auction.get_wasserstein_cost(); + } else { + hera::ws::AuctionRunnerJacR auction(set_A_copy, set_B_copy, params); + auction.run_auction(); + return auction.get_wasserstein_cost(); + } + +} + +double wasserstein_dist(const DynamicPointVector& set_A, const DynamicPointVector& set_B, const AuctionParams& params) +{ + return std::pow(wasserstein_cost(set_A, set_B, params), 1.0 / params.wasserstein_power); +} + +} // ws +} // hera + + +#endif diff --git a/src/rips.h b/src/rips.h index ab11cd9..531bf93 100644 --- a/src/rips.h +++ b/src/rips.h @@ -1,261 +1,217 @@ -#include -#include - -// for changing formats and typecasting -#include - -//for GUDHI -#include - -// for Dionysus -#include - -// for phat -#include - -// for Rips -#include -#include -#include -#include - - -// ripsFiltration -/** \brief Interface for R code, construct the rips filtration on the input - * set of points. - * - * @param[out] Rcpp::List A list - * @param[in] X Either an nxd matrix of coordinates, - * or an nxn matrix of distances of points - * @param[in] maxdimension Max dimension of the homological features to be computed. - * @param[in] maxscale Threshold for the Rips complex - * @param[in] dist "euclidean" for Euclidean distance, - * "arbitrary" for an arbitrary distance - * @param[in] library Either "GUDHI" or "Dionysus" - * @param[in] printProgress Is progress printed? - * @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the - * diagram. Diagram must point to enough memory space for - * 3*max_num_pairs double. If there is not enough pairs in the diagram, - * write nothing after. - */ -template< typename IntVector, typename RealMatrix, typename VectorList, - typename RealVector, typename Print > -inline void ripsFiltration( - const RealMatrix & X, - const unsigned nSample, - const unsigned nDim, - const int maxdimension, - const double maxscale, - const std::string & dist, - const std::string & library, - const bool printProgress, - const Print & print, - VectorList & cmplx, - RealVector & values, - VectorList & boundary -) { - if (library[0] == 'G') { - Gudhi::Simplex_tree<> smplxTree = - RipsFiltrationGudhi< Gudhi::Simplex_tree<> >(X, nSample, nDim, - maxdimension, maxscale, printProgress, print); - filtrationGudhiToTda< IntVector >(smplxTree, cmplx, values, boundary); - } - - else { - - if (dist[0] == 'e') { - // RipsDiag for L2 distance - if (library[0] == 'D' && library[1] == '2') { - filtrationDionysus2Tda< IntVector >( - RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, - nDim, false, maxdimension, maxscale, printProgress, print), - cmplx, values, boundary); - } - else{ - filtrationDionysusToTda< IntVector >( - RipsFiltrationDionysus< PairDistances, Generator, FltrR >(X, nSample, - nDim, false, maxdimension, maxscale, printProgress, print), - cmplx, values, boundary); - } - } - else { - - if (library[0] == 'D' && library[1] == '2') { - filtrationDionysus2Tda< IntVector >( - RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, - nDim, true, maxdimension, maxscale, printProgress, print), - cmplx, values, boundary); - } else { - // RipsDiag for arbitrary distance - filtrationDionysusToTda< IntVector >( - RipsFiltrationDionysus< PairDistancesA, GeneratorA, FltrRA >(X, - nSample, nDim, true, maxdimension, maxscale, printProgress, - print), - cmplx, values, boundary); - } - } - } -} - - - -// ripsDiag -/** \brief Interface for R code, construct the persistence diagram - * of the Rips complex constructed on the input set of points. - * - * @param[out] Rcpp::List A list - * @param[in] X Either an nxd matrix of coordinates, - * or an nxn matrix of distances of points - * @param[in] maxdimension Max dimension of the homological features to be computed. - * @param[in] maxscale Threshold for the Rips complex - * @param[in] dist "euclidean" for Euclidean distance, - * "arbitrary" for an arbitrary distance - * @param[in] libraryFiltration Either "GUDHI" or "Dionysus" - * @param[in] libraryDiag Either "GUDHI", "Dionysus", or "PHAT" - * @param[in] location Are location of birth point, death point, - * and representative cycles returned? - * @param[in] printProgress Is progress printed? - * @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the - * diagram. Diagram must point to enough memory space for - * 3*max_num_pairs double. If there is not enough pairs in the diagram, - * write nothing after. - */ -template< typename RealMatrix, typename Print > -inline void ripsDiag( - const RealMatrix & X, - const unsigned nSample, - const unsigned nDim, - const int maxdimension, - const double maxscale, - const std::string & dist, - const std::string & libraryFiltration, - const std::string & libraryDiag, - const bool location, - const bool printProgress, - const Print & print, - std::vector< std::vector< std::vector< double > > > & persDgm, - std::vector< std::vector< std::vector< unsigned > > > & persLoc, - std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle -) { - - if (libraryFiltration[0] == 'G') { - Gudhi::Simplex_tree<> smplxTree = - RipsFiltrationGudhi< Gudhi::Simplex_tree<> >(X, nSample, nDim, - maxdimension, maxscale, printProgress, print); - - // Compute the persistence diagram of the complex - if (libraryDiag[0] == 'G') { - int p = 2; //characteristic of the coefficient field for homology - double min_persistence = 0; //minimal length for persistent intervals - FiltrationDiagGudhi( - smplxTree, p, min_persistence, maxdimension, printProgress, persDgm); - } - else if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { - FltrR2 filtration = filtrationGudhiToDionysus2< FltrR2 >(smplxTree); - FiltrationDiagDionysus2< Persistence2 >( - filtration, maxdimension, location, printProgress, persDgm, persLoc, - persCycle); - } - else if (libraryDiag[0] == 'D') { - FltrR filtration = filtrationGudhiToDionysus< FltrR >(smplxTree); - FiltrationDiagDionysus< Persistence >( - filtration, maxdimension, location, printProgress, persDgm, persLoc, - persCycle); - } - else { - std::vector< phat::column > cmplx; - std::vector< double > values; - phat::boundary_matrix< phat::vector_vector > boundary_matrix; - filtrationGudhiToPhat< phat::column, phat::dimension >( - smplxTree, cmplx, values, boundary_matrix); - FiltrationDiagPhat( - cmplx, values, boundary_matrix, maxdimension, location, - printProgress, persDgm, persLoc, persCycle); - } - } - else { - if (dist[0] == 'e') { - // RipsDiag for L2 distance - if (libraryDiag[0] == 'D' && libraryDiag[0] == '2') { - FiltrationDiagDionysus2( - RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, - nDim, false, maxdimension, maxscale, printProgress, print), - maxdimension, location, printProgress, persDgm, persLoc, persCycle - ); - } - else { - FltrR filtration = - RipsFiltrationDionysus< PairDistances, Generator, FltrR >( - X, nSample, nDim, false, maxdimension, maxscale, - printProgress, print); - - if (libraryDiag[0] == 'D') { - FiltrationDiagDionysus< Persistence >( - filtration, maxdimension, location, printProgress, persDgm, - persLoc, persCycle); - } - else if (libraryDiag[0] == 'G') { - Gudhi::Simplex_tree<> smplxTree = - filtrationDionysusToGudhi< Gudhi::Simplex_tree<> >(filtration); - int p = 2; //characteristic of the coefficient field for homology - double min_persistence = 0; //minimal length for persistent intervals - FiltrationDiagGudhi( - smplxTree, p, min_persistence, maxdimension, printProgress, - persDgm); - } - else { - std::vector< phat::column > cmplx; - std::vector< double > values; - phat::boundary_matrix< phat::vector_vector > boundary_matrix; - filtrationDionysusToPhat< phat::column, phat::dimension >( - filtration, cmplx, values, boundary_matrix); - FiltrationDiagPhat( - cmplx, values, boundary_matrix, maxdimension, location, - printProgress, persDgm, persLoc, persCycle); - } - } - } - else { - // RipsDiag for arbitrary distance - - if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { - FiltrationDiagDionysus2( - RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, - nDim, true, maxdimension, maxscale, printProgress, print), - maxdimension, location, printProgress, persDgm, persLoc, persCycle); - } else { - - FltrRA filtration = - RipsFiltrationDionysus< PairDistancesA, GeneratorA, FltrRA >( - X, nSample, nDim, true, maxdimension, maxscale, - printProgress, print); - - if (libraryDiag[0] == 'D') { - FiltrationDiagDionysus< Persistence >( - filtration, maxdimension, location, printProgress, persDgm, - persLoc, persCycle); - } - else if (libraryDiag[0] == 'G') { - Gudhi::Simplex_tree<> smplxTree = - filtrationDionysusToGudhi< Gudhi::Simplex_tree<> >(filtration); - int p = 2; //characteristic of the coefficient field for homology - double min_persistence = 0; //minimal length for persistent intervals - FiltrationDiagGudhi( - smplxTree, p, min_persistence, maxdimension, printProgress, - persDgm); - } - else { - std::vector< phat::column > cmplx; - std::vector< double > values; - phat::boundary_matrix< phat::vector_vector > boundary_matrix; - filtrationDionysusToPhat< phat::column, phat::dimension >( - filtration, cmplx, values, boundary_matrix); - FiltrationDiagPhat( - cmplx, values, boundary_matrix, maxdimension, location, - printProgress, persDgm, persLoc, persCycle); - } - } - } - } -} - +#include +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include + +// for phat +#include + +// for Rips +#include +#include + + +// ripsFiltration +/** \brief Interface for R code, construct the rips filtration on the input + * set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X Either an nxd matrix of coordinates, + * or an nxn matrix of distances of points + * @param[in] maxdimension Max dimension of the homological features to be computed. + * @param[in] maxscale Threshold for the Rips complex + * @param[in] dist "euclidean" for Euclidean distance, + * "arbitrary" for an arbitrary distance + * @param[in] library Either "GUDHI" or "Dionysus" + * @param[in] printProgress Is progress printed? + * @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the + * diagram. Diagram must point to enough memory space for + * 3*max_num_pairs double. If there is not enough pairs in the diagram, + * write nothing after. + */ +template< typename IntVector, typename RealMatrix, typename VectorList, + typename RealVector, typename Print > +inline void ripsFiltration( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const double maxscale, + const std::string & dist, + const std::string & library, + const bool printProgress, + const Print & print, + VectorList & cmplx, + RealVector & values, + VectorList & boundary +) { + if (library[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + RipsFiltrationGudhi< Gudhi::Simplex_tree<> >(X, nSample, nDim, + maxdimension, maxscale, printProgress, print); + filtrationGudhiToTda< IntVector >(smplxTree, cmplx, values, boundary); + } + else { + + if (dist[0] == 'e') { + // RipsDiag for L2 distance + filtrationDionysus2Tda< IntVector >( + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } + else { + // RipsDiag for arbitrary distance + filtrationDionysus2Tda< IntVector >( + RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, + nSample, nDim, true, maxdimension, maxscale, printProgress, + print), + cmplx, values, boundary); + } + } +} + + + +// ripsDiag +/** \brief Interface for R code, construct the persistence diagram + * of the Rips complex constructed on the input set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X Either an nxd matrix of coordinates, + * or an nxn matrix of distances of points + * @param[in] maxdimension Max dimension of the homological features to be computed. + * @param[in] maxscale Threshold for the Rips complex + * @param[in] dist "euclidean" for Euclidean distance, + * "arbitrary" for an arbitrary distance + * @param[in] libraryFiltration Either "GUDHI" or "Dionysus" + * @param[in] libraryDiag Either "GUDHI", "Dionysus", or "PHAT" + * @param[in] location Are location of birth point, death point, + * and representative cycles returned? + * @param[in] printProgress Is progress printed? + * @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the + * diagram. Diagram must point to enough memory space for + * 3*max_num_pairs double. If there is not enough pairs in the diagram, + * write nothing after. + */ +template< typename RealMatrix, typename Print > +inline void ripsDiag( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const double maxscale, + const std::string & dist, + const std::string & libraryFiltration, + const std::string & libraryDiag, + const bool location, + const bool printProgress, + const Print & print, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + if (libraryFiltration[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + RipsFiltrationGudhi< Gudhi::Simplex_tree<> >(X, nSample, nDim, + maxdimension, maxscale, printProgress, print); + + // Compute the persistence diagram of the complex + if (libraryDiag[0] == 'G') { + int p = 2; //characteristic of the coefficient field for homology + double min_persistence = 0; //minimal length for persistent intervals + FiltrationDiagGudhi( + smplxTree, p, min_persistence, maxdimension, printProgress, persDgm); + } + else if (libraryDiag[0] == 'D') { + FltrR2 filtration = filtrationGudhiToDionysus2< FltrR2 >(smplxTree); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationGudhiToPhat< phat::column, phat::dimension >( + smplxTree, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } + } + else { + if (dist[0] == 'e') { + // RipsDiag for L2 distance + FltrR2 filtration = + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >( + X, nSample, nDim, false, maxdimension, maxscale, + printProgress, print); + + if (libraryDiag[0] == 'D') { + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, + persLoc, persCycle); + } + else if (libraryDiag[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + filtrationDionysus2Gudhi< Gudhi::Simplex_tree<> >(filtration); + int p = 2; //characteristic of the coefficient field for homology + double min_persistence = 0; //minimal length for persistent intervals + FiltrationDiagGudhi( + smplxTree, p, min_persistence, maxdimension, printProgress, + persDgm); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationDionysus2ToPhat< phat::column, phat::dimension >( + filtration, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } + } + else { + // RipsDiag for arbitrary distance + FltrR2A filtration = + RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >( + X, nSample, nDim, true, maxdimension, maxscale, + printProgress, print); + + if (libraryDiag[0] == 'D') { + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, + persLoc, persCycle); + } + else if (libraryDiag[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + filtrationDionysus2Gudhi< Gudhi::Simplex_tree<> >(filtration); + int p = 2; //characteristic of the coefficient field for homology + double min_persistence = 0; //minimal length for persistent intervals + FiltrationDiagGudhi( + smplxTree, p, min_persistence, maxdimension, printProgress, + persDgm); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationDionysus2ToPhat< phat::column, phat::dimension >( + filtration, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } + } + } +} diff --git a/src/ripsb.h b/src/ripsb.h new file mode 100644 index 0000000..bfe338f --- /dev/null +++ b/src/ripsb.h @@ -0,0 +1,262 @@ +#include +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include +#include + +// for phat +#include + +// for Rips +#include +#include +#include +#include + + +// ripsFiltration +/** \brief Interface for R code, construct the rips filtration on the input + * set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X Either an nxd matrix of coordinates, + * or an nxn matrix of distances of points + * @param[in] maxdimension Max dimension of the homological features to be computed. + * @param[in] maxscale Threshold for the Rips complex + * @param[in] dist "euclidean" for Euclidean distance, + * "arbitrary" for an arbitrary distance + * @param[in] library Either "GUDHI" or "Dionysus" + * @param[in] printProgress Is progress printed? + * @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the + * diagram. Diagram must point to enough memory space for + * 3*max_num_pairs double. If there is not enough pairs in the diagram, + * write nothing after. + */ +template< typename IntVector, typename RealMatrix, typename VectorList, + typename RealVector, typename Print > +inline void ripsFiltration( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const double maxscale, + const std::string & dist, + const std::string & library, + const bool printProgress, + const Print & print, + VectorList & cmplx, + RealVector & values, + VectorList & boundary +) { + if (library[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + RipsFiltrationGudhi< Gudhi::Simplex_tree<> >(X, nSample, nDim, + maxdimension, maxscale, printProgress, print); + filtrationGudhiToTda< IntVector >(smplxTree, cmplx, values, boundary); + } + + else { + + if (dist[0] == 'e') { + // RipsDiag for L2 distance + if (library[0] == 'D' && library[1] == '2') { + filtrationDionysus2Tda< IntVector >( + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } + else{ + filtrationDionysusToTda< IntVector >( + RipsFiltrationDionysus< PairDistances, Generator, FltrR >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } + } + else { + + if (library[0] == 'D' && library[1] == '2') { + filtrationDionysus2Tda< IntVector >( + RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, + nDim, true, maxdimension, maxscale, printProgress, print), + cmplx, values, boundary); + } else { + // RipsDiag for arbitrary distance + filtrationDionysusToTda< IntVector >( + RipsFiltrationDionysus< PairDistancesA, GeneratorA, FltrRA >(X, + nSample, nDim, true, maxdimension, maxscale, printProgress, + print), + cmplx, values, boundary); + } + } + } +} + + + +// ripsDiag +/** \brief Interface for R code, construct the persistence diagram + * of the Rips complex constructed on the input set of points. + * + * @param[out] Rcpp::List A list + * @param[in] X Either an nxd matrix of coordinates, + * or an nxn matrix of distances of points + * @param[in] maxdimension Max dimension of the homological features to be computed. + * @param[in] maxscale Threshold for the Rips complex + * @param[in] dist "euclidean" for Euclidean distance, + * "arbitrary" for an arbitrary distance + * @param[in] libraryFiltration Either "GUDHI" or "Dionysus" + * @param[in] libraryDiag Either "GUDHI", "Dionysus", or "PHAT" + * @param[in] location Are location of birth point, death point, + * and representative cycles returned? + * @param[in] printProgress Is progress printed? + * @param[in] max_num_bars Write the max_num_pairs most persistent pairs of the + * diagram. Diagram must point to enough memory space for + * 3*max_num_pairs double. If there is not enough pairs in the diagram, + * write nothing after. + */ +template< typename RealMatrix, typename Print > +inline void ripsDiag( + const RealMatrix & X, + const unsigned nSample, + const unsigned nDim, + const int maxdimension, + const double maxscale, + const std::string & dist, + const std::string & libraryFiltration, + const std::string & libraryDiag, + const bool location, + const bool printProgress, + const Print & print, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + if (libraryFiltration[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + RipsFiltrationGudhi< Gudhi::Simplex_tree<> >(X, nSample, nDim, + maxdimension, maxscale, printProgress, print); + + // Compute the persistence diagram of the complex + if (libraryDiag[0] == 'G') { + int p = 2; //characteristic of the coefficient field for homology + double min_persistence = 0; //minimal length for persistent intervals + FiltrationDiagGudhi( + smplxTree, p, min_persistence, maxdimension, printProgress, persDgm); + } + else if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { + FltrR2 filtration = filtrationGudhiToDionysus2< FltrR2 >(smplxTree); + FiltrationDiagDionysus2< Persistence2 >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else if (libraryDiag[0] == 'D') { + FltrR filtration = filtrationGudhiToDionysus< FltrR >(smplxTree); + FiltrationDiagDionysus< Persistence >( + filtration, maxdimension, location, printProgress, persDgm, persLoc, + persCycle); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationGudhiToPhat< phat::column, phat::dimension >( + smplxTree, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } + } + else { + if (dist[0] == 'e') { + // RipsDiag for L2 distance + if (libraryDiag[0] == 'D' && libraryDiag[0] == '2') { + FiltrationDiagDionysus2( + RipsFiltrationDionysus2< PairDistances2, Generator2, FltrR2 >(X, nSample, + nDim, false, maxdimension, maxscale, printProgress, print), + maxdimension, location, printProgress, persDgm, persLoc, persCycle + ); + } + else { + FltrR filtration = + RipsFiltrationDionysus< PairDistances, Generator, FltrR >( + X, nSample, nDim, false, maxdimension, maxscale, + printProgress, print); + + if (libraryDiag[0] == 'D') { + FiltrationDiagDionysus< Persistence >( + filtration, maxdimension, location, printProgress, persDgm, + persLoc, persCycle); + } + else if (libraryDiag[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + filtrationDionysusToGudhi< Gudhi::Simplex_tree<> >(filtration); + int p = 2; //characteristic of the coefficient field for homology + double min_persistence = 0; //minimal length for persistent intervals + FiltrationDiagGudhi( + smplxTree, p, min_persistence, maxdimension, printProgress, + persDgm); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationDionysusToPhat< phat::column, phat::dimension >( + filtration, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } + } + } + else { + // RipsDiag for arbitrary distance + + if (libraryDiag[0] == 'D' && libraryDiag[1] == '2') { + FiltrationDiagDionysus2( + RipsFiltrationDionysus2< PairDistances2A, Generator2A, FltrR2A >(X, nSample, + nDim, true, maxdimension, maxscale, printProgress, print), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } else { + + FltrRA filtration = + RipsFiltrationDionysus< PairDistancesA, GeneratorA, FltrRA >( + X, nSample, nDim, true, maxdimension, maxscale, + printProgress, print); + + if (libraryDiag[0] == 'D') { + FiltrationDiagDionysus< Persistence >( + filtration, maxdimension, location, printProgress, persDgm, + persLoc, persCycle); + } + else if (libraryDiag[0] == 'G') { + Gudhi::Simplex_tree<> smplxTree = + filtrationDionysusToGudhi< Gudhi::Simplex_tree<> >(filtration); + int p = 2; //characteristic of the coefficient field for homology + double min_persistence = 0; //minimal length for persistent intervals + FiltrationDiagGudhi( + smplxTree, p, min_persistence, maxdimension, printProgress, + persDgm); + } + else { + std::vector< phat::column > cmplx; + std::vector< double > values; + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationDionysusToPhat< phat::column, phat::dimension >( + filtration, cmplx, values, boundary_matrix); + FiltrationDiagPhat( + cmplx, values, boundary_matrix, maxdimension, location, + printProgress, persDgm, persLoc, persCycle); + } + } + } + } +} + diff --git a/src/tdautils/.swp b/src/tdautils/.swp new file mode 100644 index 0000000000000000000000000000000000000000..3572d2ce8e6b3acf51ae542f2e229a7f97fe9ab2 GIT binary patch literal 12288 zcmeI%u?~VT5P;#s%}Jw^8|qvFniyv{2NDNyuQ*`D)Cjw;=o|PbN>N585^(sNw3jw* z@A7R;kN#mBx@#|xq7CivZO@5Fo|vo7lJs0H*2I0eFH@n8H^yAl?rzk|OHHiPniDJ| zfI#~K{b;!ee4BSZo4Vss)P4zML;wK<5I_I{1Q0*~fzAZ-n -#include -#include - -#include - -// for changing formats and typecasting -#include - -//for GUDHI -#include - -// for Dionysus -#include -#include - -// for phat -#include - -// for grid -#include - -#include - - - -// FiltrationDiag -/** \brief Interface for R code, construct the persistence diagram from the -* filtration. -* -* @param[out] Rcpp::List A list -* @param[in] filtration The input filtration -* @param[in] maxdimension Max dimension of the homological features to be -* computed. -* @param[in] library Either "GUDHI", "Dionysus", or "PHAT" -* @param[in] location Are location of birth point, death point, and -* representative cycles returned? -* @param[in] printProgress Is progress printed? -*/ -// TODO: see whether IntegerVector in template is deducible -template< typename VertexVector, typename VectorList, typename RealVector > -inline void filtrationDiagSorted( - VectorList & cmplx, - RealVector & values, - const int maxdimension, - const std::string & library, - const bool location, - const bool printProgress, - const unsigned idxShift, - std::vector< std::vector< std::vector< double > > > & persDgm, - std::vector< std::vector< std::vector< unsigned > > > & persLoc, - std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle -) { - - if (library[0] == 'G') { - int coeff_field_characteristic = 2; - double min_persistence = 0.0; - Gudhi::Simplex_tree<> smplxTree = filtrationTdaToGudhi< - VertexVector, Gudhi::Simplex_tree<> >( - cmplx, values, idxShift); - FiltrationDiagGudhi( - smplxTree, coeff_field_characteristic, min_persistence, maxdimension, - printProgress, persDgm); - } - else if (library[0] == 'D' && library[1] == '2') { - FiltrationDiagDionysus2( - filtrationTdaToDionysus2< VertexVector, Fltr2>( - cmplx, values, idxShift), - maxdimension, location, printProgress, persDgm, persLoc, persCycle); - } - else if (library[0] == 'D') { - FiltrationDiagDionysus< Persistence >( - filtrationTdaToDionysus< VertexVector, Fltr >( - cmplx, values, idxShift), - maxdimension, location, printProgress, persDgm, persLoc, persCycle); - } - else { - - std::vector< phat::column > cmplxPhat(cmplx.size()); - typename VectorList::iterator iCmplx = cmplx.begin(); - std::vector< phat::column >::iterator iPhat = cmplxPhat.begin(); - for (; iCmplx != cmplx.end(); ++iCmplx, ++iPhat) { - VertexVector cmplxVec(*iCmplx); - *iPhat = phat::column(cmplxVec.begin(), cmplxVec.end()); - } - - phat::boundary_matrix< phat::vector_vector > boundary_matrix; - filtrationDionysusToPhat< phat::column, phat::dimension >( - filtrationTdaToDionysus< phat::column, Fltr >( - cmplxPhat, values, idxShift), - cmplxPhat, values, boundary_matrix); - FiltrationDiagPhat(cmplxPhat, values, boundary_matrix, - maxdimension, location, printProgress, persDgm, persLoc, persCycle); - } -} - - - -// FiltrationDiag -/** \brief Interface for R code, construct the persistence diagram from the -* filtration. -* -* @param[out] Rcpp::List A list -* @param[in] filtration The input filtration -* @param[in] maxdimension Max dimension of the homological features to be -* computed. -* @param[in] library Either "GUDHI", "Dionysus", or "PHAT" -* @param[in] location Are location of birth point, death point, and -* representative cycles returned? -* @param[in] printProgress Is progress printed? -*/ -// TODO: see whether IntegerVector in template is deducible -template< typename VertexVector, typename VectorList, typename RealVector > -inline void filtrationDiag( - VectorList & cmplx, - RealVector & values, - const int maxdimension, - const std::string & library, - const bool location, - const bool printProgress, - const unsigned idxShift, - std::vector< std::vector< std::vector< double > > > & persDgm, - std::vector< std::vector< std::vector< unsigned > > > & persLoc, - std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle -) { - - if (std::is_sorted(values.begin(), values.end())) { - filtrationDiagSorted< VertexVector >( - cmplx, values, maxdimension, library, location, printProgress, - idxShift, persDgm, persLoc, persCycle); - } - else { - std::vector< std::vector< unsigned > > cmplxTemp = - RcppCmplxToStl< std::vector< unsigned >, VertexVector >(cmplx, 0); - std::vector< double > valuesTemp(values.begin(), values.end()); - filtrationSort(cmplxTemp, valuesTemp); - filtrationDiagSorted< std::vector< unsigned > >( - cmplxTemp, valuesTemp, maxdimension, library, location, printProgress, - idxShift, persDgm, persLoc, persCycle); - } -} - - - -# endif // __FILTRATIONDIAG_H__ + +#ifndef __FILTRATIONDIAG_H__ +#define __FILTRATIONDIAG_H__ + +#include +#include +#include + +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include + +// for phat +#include + +// for grid +#include + +#include + + + +// FiltrationDiag +/** \brief Interface for R code, construct the persistence diagram from the +* filtration. +* +* @param[out] Rcpp::List A list +* @param[in] filtration The input filtration +* @param[in] maxdimension Max dimension of the homological features to be +* computed. +* @param[in] library Either "GUDHI", "Dionysus", or "PHAT" +* @param[in] location Are location of birth point, death point, and +* representative cycles returned? +* @param[in] printProgress Is progress printed? +*/ +// TODO: see whether IntegerVector in template is deducible +template< typename VertexVector, typename VectorList, typename RealVector > +inline void filtrationDiagSorted( + VectorList & cmplx, + RealVector & values, + const int maxdimension, + const std::string & library, + const bool location, + const bool printProgress, + const unsigned idxShift, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + if (library[0] == 'G') { + int coeff_field_characteristic = 2; + double min_persistence = 0.0; + Gudhi::Simplex_tree<> smplxTree = filtrationTdaToGudhi< + VertexVector, Gudhi::Simplex_tree<> >( + cmplx, values, idxShift); + FiltrationDiagGudhi( + smplxTree, coeff_field_characteristic, min_persistence, maxdimension, + printProgress, persDgm); + } + else if (library[0] == 'D') { + FiltrationDiagDionysus2< Persistence2 >( + filtrationTdaToDionysus2< VertexVector, Fltr2 >( + cmplx, values, idxShift), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } + else { + + std::vector< phat::column > cmplxPhat(cmplx.size()); + typename VectorList::iterator iCmplx = cmplx.begin(); + std::vector< phat::column >::iterator iPhat = cmplxPhat.begin(); + for (; iCmplx != cmplx.end(); ++iCmplx, ++iPhat) { + VertexVector cmplxVec(*iCmplx); + *iPhat = phat::column(cmplxVec.begin(), cmplxVec.end()); + } + + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationDionysus2ToPhat< phat::column, phat::dimension >( + filtrationTdaToDionysus2< phat::column, Fltr2 >( + cmplxPhat, values, idxShift), + cmplxPhat, values, boundary_matrix); + FiltrationDiagPhat(cmplxPhat, values, boundary_matrix, + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } +} + + + +// FiltrationDiag +/** \brief Interface for R code, construct the persistence diagram from the +* filtration. +* +* @param[out] Rcpp::List A list +* @param[in] filtration The input filtration +* @param[in] maxdimension Max dimension of the homological features to be +* computed. +* @param[in] library Either "GUDHI", "Dionysus", or "PHAT" +* @param[in] location Are location of birth point, death point, and +* representative cycles returned? +* @param[in] printProgress Is progress printed? +*/ +// TODO: see whether IntegerVector in template is deducible +template< typename VertexVector, typename VectorList, typename RealVector > +inline void filtrationDiag( + VectorList & cmplx, + RealVector & values, + const int maxdimension, + const std::string & library, + const bool location, + const bool printProgress, + const unsigned idxShift, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + if (std::is_sorted(values.begin(), values.end())) { + filtrationDiagSorted< VertexVector >( + cmplx, values, maxdimension, library, location, printProgress, + idxShift, persDgm, persLoc, persCycle); + } + else { + std::vector< std::vector< unsigned > > cmplxTemp = + RcppCmplxToStl< std::vector< unsigned >, VertexVector >(cmplx, 0); + std::vector< double > valuesTemp(values.begin(), values.end()); + filtrationSort(cmplxTemp, valuesTemp); + filtrationDiagSorted< std::vector< unsigned > >( + cmplxTemp, valuesTemp, maxdimension, library, location, printProgress, + idxShift, persDgm, persLoc, persCycle); + } +} + + + +# endif // __FILTRATIONDIAG_H__ diff --git a/src/tdautils/filtrationDiagb.h b/src/tdautils/filtrationDiagb.h new file mode 100644 index 0000000..eac5529 --- /dev/null +++ b/src/tdautils/filtrationDiagb.h @@ -0,0 +1,148 @@ +#ifndef __FILTRATIONDIAG_H__ +#define __FILTRATIONDIAG_H__ + +#include +#include +#include + +#include + +// for changing formats and typecasting +#include + +//for GUDHI +#include + +// for Dionysus +#include +#include + +// for phat +#include + +// for grid +#include + +#include + + + +// FiltrationDiag +/** \brief Interface for R code, construct the persistence diagram from the +* filtration. +* +* @param[out] Rcpp::List A list +* @param[in] filtration The input filtration +* @param[in] maxdimension Max dimension of the homological features to be +* computed. +* @param[in] library Either "GUDHI", "Dionysus", or "PHAT" +* @param[in] location Are location of birth point, death point, and +* representative cycles returned? +* @param[in] printProgress Is progress printed? +*/ +// TODO: see whether IntegerVector in template is deducible +template< typename VertexVector, typename VectorList, typename RealVector > +inline void filtrationDiagSorted( + VectorList & cmplx, + RealVector & values, + const int maxdimension, + const std::string & library, + const bool location, + const bool printProgress, + const unsigned idxShift, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + if (library[0] == 'G') { + int coeff_field_characteristic = 2; + double min_persistence = 0.0; + Gudhi::Simplex_tree<> smplxTree = filtrationTdaToGudhi< + VertexVector, Gudhi::Simplex_tree<> >( + cmplx, values, idxShift); + FiltrationDiagGudhi( + smplxTree, coeff_field_characteristic, min_persistence, maxdimension, + printProgress, persDgm); + } + else if (library[0] == 'D' && library[1] == '2') { + FiltrationDiagDionysus2( + filtrationTdaToDionysus2< VertexVector, Fltr2>( + cmplx, values, idxShift), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } + else if (library[0] == 'D') { + FiltrationDiagDionysus< Persistence >( + filtrationTdaToDionysus< VertexVector, Fltr >( + cmplx, values, idxShift), + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } + else { + + std::vector< phat::column > cmplxPhat(cmplx.size()); + typename VectorList::iterator iCmplx = cmplx.begin(); + std::vector< phat::column >::iterator iPhat = cmplxPhat.begin(); + for (; iCmplx != cmplx.end(); ++iCmplx, ++iPhat) { + VertexVector cmplxVec(*iCmplx); + *iPhat = phat::column(cmplxVec.begin(), cmplxVec.end()); + } + + phat::boundary_matrix< phat::vector_vector > boundary_matrix; + filtrationDionysusToPhat< phat::column, phat::dimension >( + filtrationTdaToDionysus< phat::column, Fltr >( + cmplxPhat, values, idxShift), + cmplxPhat, values, boundary_matrix); + FiltrationDiagPhat(cmplxPhat, values, boundary_matrix, + maxdimension, location, printProgress, persDgm, persLoc, persCycle); + } +} + + + +// FiltrationDiag +/** \brief Interface for R code, construct the persistence diagram from the +* filtration. +* +* @param[out] Rcpp::List A list +* @param[in] filtration The input filtration +* @param[in] maxdimension Max dimension of the homological features to be +* computed. +* @param[in] library Either "GUDHI", "Dionysus", or "PHAT" +* @param[in] location Are location of birth point, death point, and +* representative cycles returned? +* @param[in] printProgress Is progress printed? +*/ +// TODO: see whether IntegerVector in template is deducible +template< typename VertexVector, typename VectorList, typename RealVector > +inline void filtrationDiag( + VectorList & cmplx, + RealVector & values, + const int maxdimension, + const std::string & library, + const bool location, + const bool printProgress, + const unsigned idxShift, + std::vector< std::vector< std::vector< double > > > & persDgm, + std::vector< std::vector< std::vector< unsigned > > > & persLoc, + std::vector< std::vector< std::vector< std::vector< unsigned > > > > & persCycle +) { + + if (std::is_sorted(values.begin(), values.end())) { + filtrationDiagSorted< VertexVector >( + cmplx, values, maxdimension, library, location, printProgress, + idxShift, persDgm, persLoc, persCycle); + } + else { + std::vector< std::vector< unsigned > > cmplxTemp = + RcppCmplxToStl< std::vector< unsigned >, VertexVector >(cmplx, 0); + std::vector< double > valuesTemp(values.begin(), values.end()); + filtrationSort(cmplxTemp, valuesTemp); + filtrationDiagSorted< std::vector< unsigned > >( + cmplxTemp, valuesTemp, maxdimension, library, location, printProgress, + idxShift, persDgm, persLoc, persCycle); + } +} + + + +# endif // __FILTRATIONDIAG_H__ diff --git a/src/tdautils/gridUtils.h b/src/tdautils/gridUtils.h index 3a89f02..fab285b 100644 --- a/src/tdautils/gridUtils.h +++ b/src/tdautils/gridUtils.h @@ -7,7 +7,7 @@ //#include //#include //#include -//#include +#include #include #include @@ -29,7 +29,7 @@ #endif - +/* typedef unsigned Vertex; typedef Simplex Smplx; typedef Smplx::VertexContainer VertexCont; @@ -43,13 +43,30 @@ typedef OffsetBeginMap FiltrationPersistenceMap; - +*/ //dionysus2 //needs changing -typedef d::Simplex Smplx2; +//typedef d::Simplex Smplx2; +//typedef d::Filtration Fltr2; +//typedef d::ReducedMatrix Persistence2; +//typedef d::StandardReduction StandardReduction2; + +typedef unsigned Vertex; +typedef d::Simplex Smplx2; +//typedef Smplx::VertexContainer VertexCont; +typedef std::vector VertexVector; typedef d::Filtration Fltr2; typedef d::ReducedMatrix Persistence2; typedef d::StandardReduction StandardReduction2; +typedef d::Diagram PDgm; +/* +typedef OffsetBeginMap PersistenceFiltrationMap; +typedef OffsetBeginMap FiltrationPersistenceMap; +*/ // add a single edge to the filtration template< typename VectorList > diff --git a/src/tdautils/ripsD2L2.h b/src/tdautils/ripsD2L2.h index be30224..3cb5db7 100644 --- a/src/tdautils/ripsD2L2.h +++ b/src/tdautils/ripsD2L2.h @@ -26,7 +26,7 @@ typedef d::PairwiseDistances> Pai typedef PairDistances2::DistanceType DistanceType2; typedef PairDistances2::IndexType VertexR2; -typedef d::Rips< PairDistances2, d::Simplex< VertexR, double > > Generator2; +typedef d::Rips< PairDistances2, d::Simplex< VertexR2, double > > Generator2; typedef Generator2::Simplex SmplxR2; typedef d::Filtration FltrR2; diff --git a/src/tdautils/typecastUtils.h b/src/tdautils/typecastUtils.h index 1980f57..69121a6 100644 --- a/src/tdautils/typecastUtils.h +++ b/src/tdautils/typecastUtils.h @@ -19,7 +19,29 @@ inline PersistenceDiagram RcppToDionysus(const RcppMatrix& rcppMatrix) { return dionysusDiagram; } +template +inline PersistenceDiagram RcppToDionysus2(const RcppMatrix& rcppMatrix) { + PersistenceDiagram dionysusDiagram; + const unsigned rowNum = rcppMatrix.nrow(); + for (unsigned rowIdx = 0; rowIdx < rowNum; ++rowIdx) + { + dionysusDiagram.push_back(typename PersistenceDiagram::Point( + rcppMatrix[rowIdx + 0 * rowNum], rcppMatrix[rowIdx + 1 * rowNum], d::Empty())); + } + return dionysusDiagram; +} +template +inline PairVector RcppToPairVector(const RcppMatrix& rcppMatrix) { + PairVector dionysusDiagram; + const unsigned rowNum = rcppMatrix.nrow(); + for (unsigned rowIdx = 0; rowIdx < rowNum; ++rowIdx) + { + dionysusDiagram.push_back(std::pair( + rcppMatrix[rowIdx + 0 * rowNum], rcppMatrix[rowIdx + 1 * rowNum])); + } + return dionysusDiagram; +} template< typename StlMatrix, typename RealMatrix > inline StlMatrix TdaToStl(const RealMatrix & rcppMatrix, @@ -802,6 +824,31 @@ inline Filtration filtrationRcppToDionysus(const RcppList & rcppList) { return filtration; } +template< typename Filtration, typename RcppVector, typename RcppList > +inline Filtration filtrationRcppToDionysus2(const RcppList & rcppList) { + + const RcppList rcppComplex(rcppList[0]); + const RcppVector rcppValue(rcppList[1]); + Filtration filtration; + + typename RcppList::const_iterator iCmplx = rcppComplex.begin(); + typename RcppVector::const_iterator iValue = rcppValue.begin(); + for (; iCmplx != rcppComplex.end(); ++iCmplx, ++iValue) { + const RcppVector rcppVec(*iCmplx); + RcppVector dionysusVec(rcppVec.size()); + typename RcppVector::const_iterator iRcpp = rcppVec.begin(); + typename RcppVector::iterator iDionysus = dionysusVec.begin(); + for (; iRcpp != rcppVec.end(); ++iRcpp, ++iDionysus) { + // R is 1-base, while C++ is 0-base + *iDionysus = *iRcpp - 1; + } + filtration.push_back(typename Filtration::Cell(dionysusVec.size(), + dionysusVec.begin(), dionysusVec.end(), *iValue)); + } + + return filtration; +} + template< typename SimplexTree, typename Filtration > inline SimplexTree filtrationDionysusToGudhi(const Filtration & filtration) { @@ -920,7 +967,7 @@ inline void filtrationDionysusToPhat( template< typename Column, typename Dimension, typename Filtration, typename VectorList, typename RealVector, typename Boundary > -inline void filtrationDionysus2Phat( +inline void filtrationDionysus2ToPhat( const Filtration & filtration, VectorList & cmplx, RealVector & values, Boundary & boundary_matrix) { // use custom VertexComparison with Dionysus2 From a6cd186c1a269f18f86db60c5579fb44cac87b03 Mon Sep 17 00:00:00 2001 From: thomashli Date: Sat, 12 Jan 2019 17:38:57 -0500 Subject: [PATCH 29/29] need to integrate wasserstein, need to fill in persLoc and persCycle --- src/.diag.cpp.swp | Bin 36864 -> 0 bytes src/.grid.h.swp | Bin 16384 -> 0 bytes src/tdautils/dionysus2Utils.h | 159 +++++++++++++++++++++++++++++----- src/tdautils/dionysusUtils.h | 36 +------- 4 files changed, 136 insertions(+), 59 deletions(-) delete mode 100644 src/.diag.cpp.swp delete mode 100644 src/.grid.h.swp diff --git a/src/.diag.cpp.swp b/src/.diag.cpp.swp deleted file mode 100644 index 7bbc7610aae8544f00a44e25b335cae4dba5cb78..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 36864 zcmeI536LDsd4Su15gWkRSPrp)@L1XMN?Gk_R~X5wgGgFQEIK2tkP47R&rI)5XXjYv zkanRGVJ;E0bh222&iKp9tIyJEn09L()d!6t;$I8ZnS1EB;c-~amVnH}w{ zmY{_0;$wEE-+TS~f9Les<~3ZtmN{x z+-q_T=x~)eYt78EQfYY3hH`1n$zAzU#>_`z{kflDYWb$rMyq0(foFOM;h~by5~w7Guls7A3b{Wq?1b` zCK5;_kVqhrKq7%e0*M3?2_zCoB=CPv0@aae#xAmbkel^m-OmriJ`ZufZ*V`4i{1aA z`+b@F>DdgtbPjcY?{`11cR#)R)7HlvWWf(8O z18_h55N?HA;OlS^Sg;&U1O?OKcONp0SK(jan{Yk+J&eI9!-OJ~b@b7RZ zTnSgeD4Y%JVKH>VC*b4o2XGAh_6WoHIa~tE;f=!$<285)&V?27{r3#xI_QUKaDSU& z?1q&v9X-gz@59Zo4OYQQ zm<7*cAh{WCf*aum_$KUzVK@%{4I|5Cup3M`3?9O;u?uFx0~l#8gAFhZ?s07#O3hYH zF0+_bWox92zeTZXR_z^@%8X_5R)t@ywQ|wQo?=K6e>`HA&BB>>ahp=&yiv7FMK{@{ zWh?lpG^o;cxtdcWCA(PdR9VaPztksWr^{w}%*);iyPC7gs=bf2$11f-yDIT*!>LOL z+UufbtxTyfQmWWaM!cw>san;tnX|HG(}|rX{S2>Wm7O%aXEpd7$k`QTol`UO71dEN zN9}@{S6RDY6)WaNQgQG-}Ytu;sMQ^6^wlrIh0c)_Sa zL{y^yRJEkih@Z4;RV&-26K;@C@28yz%Tfhfk?mrwkVbUmaj0x(9pP7$T~vrPczRt{ zC2*fz@KRi=c37DzLhmSacdD+gu1;0R&g!g-wsP~j9cI2pmi6UF0e@ZAFlB6oNbl%!wc*g%vD6?hz7g?f ztxo=5tW?XjOjVaTkE)Ont4JtuKcXcVsgyFyX`9+c1rEx4rm9u!;*cs>g%UT+iZUsL zA|TPMn0C>VHd42vz%>r#NA@-OwP26Zq$|@8>+TCy3Tl+ z9QF2YBW7q zo^CphWNAWD-MwpOtSk+=;RL5ly-=agNJTK@h7;PU<&nBc>vB6YsS446k>)_apw1bg z4o8|;xcboQ&OFn(Z;g(W%GGV#jM*t8Iuu4;U#b2I_0y@EPDPdBbHB&m=p!dR`o@x( z_}Ckb7-^46O=#q8DuT(5yU}cg>k8#R`BArNUMDf;MziTwQmLYH`9o)H#EnMsD{L$h z16ph2t)Fa@(RMZ!mW1EWBW=(Pj5prSu#1^|Eo-SoW#)rjxh3?S-Y>&hEBb?!dGwcT zwGjQm%#Y;EO*wNU7FGYg+?j1kjL>87hptuad?hvLk9mFK8tvv~#tg}k{-2Z4tUIet zrAyT+)yB#UhyIjksZ{V1e(0TMrNZ>5YT3n9=x^^aPTGE6yl6d}C2RW#!sc0E(3`a6V|L;MUyb1kM^nY*t{3Uexhd}iF(_t3uMQ48q?gP>9FNf7| z2)v3e{vbR6--Gj@2fE>Ccn#hB4yeK(!7=bLm<}I>qv5~N)o*|T{3(d;-UdHKXa5_x z5_Z8^@DX?pz5OZp54a1?hcnkw8ll z@bo{lI7fk7Ycls96;0Yx<>YTqgL7xZP9U|ULo|iuV)hC2L1x-mdHnReR2s!xQ$mFp z#B74%+@rNYnWwanK;IBrAL=&o=j}_I6y-~;uT@2(_7$9}ua+C~Xiq8B0d(RLQfD)G zQ%CXL+Nrj770acaq8am7rj(bz!n4Uvm3QXKGQ9G|wSL7_QJx$gttD`Wr~CR(ZLV}i zG|cYSeNuA%iDPxoPM!7Ju+M% zp}>>TBBK?0rlLZGuT;u08HsVOQYv6?a|}C`5llu$hV{p#GKF)eYEI!d(jc*ArO2^A z5!VLQiD;4WIMh&G-6|L?F#`_du!rPI`K%uA#Zam9Lo3vl%6c2#*6>@Rh!jre5lPH+V(9&O6jph5{Zn-C{HAFEhmkQu{x zYfzreFEXb~d_E%cP~UjvxSceg&PFD!iO&a4SQE1ec+=U&wz`PFK-f8BGPB?Qus>4b z^Wpj>j_+~l+a4Up&BW)Uv8QFz>Cz;|+^EMQr5!~Yn*4c9^$`r%l3 z34Q)<_&XSeLD&KdKy>{#(d~Z#7l8NyoC~tYU=Mo!O>iSz01MzPbbQhMpMZaYi{Vr_ z2n^VZuKy0a4Znt`;STr$TnHz_2jEt8{q1lP^guVvg%lhM;^+T7`u?3D{{4Rei(wi( zfsX%OsKF9=5n22#h)fRQ{xA6LHn9bSWH;3~KTE`kf;0{A>^0r4wH!NKqtwuL+3R=5Q&g^jQQ*26kj3%?Hs!OPeo z?t)7p4fEhoFkmnCo)_V(@DFeuTnpF0rLY7(4W9(LWS$;$pFx$IlJ%Uya;YFII2?%ACx^u=G--D%lFx$3e~t(fgpWgVu!n6*Ybm2*eiTirDw zT(h?5;2pAf>n(DM!svR1s|NFJW+F~XL~*W$f!+1?nUYDJfx6)CeAP?E{VY!~T4_WGXaJxx$DyqUK$VuG3>q|-+d1T=Zovukn?JKe6f*= z<<@(=7NrqU>;0Aepvr4KAifkRmD3aC1loibw60WADK`u3XEA$wLr#XS4@&gF%a=k@l+7p!%5jV*zPYtyDKp}P&I z@-*xi{T?n-p96PWXY9KrBW@Hy%fA(dxKZkEm;QW{VVy2 z6z4MEy?F{sv#o8vm$XG9Z44ok^`>hna?F#?YlBYp7mt=;TNcH%FF(3Ttz^3uTZOjS z_*-*gesxjPkmqI>aRx=8T&}i3=a;qqdqC_0cR&%=!&+Da zvQOYN=z!^9z~k5i{t+$%+2?->^n&;Xd>S6dKJZic3H%s-1oy(#up26{3XX+0u@(Fd zUWNPND)=&730J_E;0#y^Js`G)`>`8b4?E#B=!4~u0@+*e3_K0q9s@Zy-~t#1EiVyN zawQT-B#=lTkw7AWL;{Hf#*si`LzhEm5*vDAL!bIZ2x)TOLec+U8@i~YDyQ|#z&zU6 z=qU>W#%kBv-)w1PYb%8@Y;jgto-O=%4f*Fi`71%;LWnFQ)bDoM+Uc$x1=}*4u50jM z*}v~Ea3F(VMI*eLF&XPZM1t-@lh;k{+hpKnO&5#vc-PuPVTG?;IIYMt{S`xz!M;On z?QB$LiqdYpn+j{e>b@1@7MSS&Da!5w!#9K53=w7FnE#m`ftE6*pP!E z7=!_kb^VuFw-=v)AHjvN7@lKIUe53vgm(B9Ywoh|U(W420%VW>4k*LtU^yswg*Ek` zz}H|5R>LY-333j@baL1(&b$X7hsE zi?zabSq81R>z`e!U+

>0S;HELD|NEY*fMxlsq=>_FK{n{`(AkWAA@xWT#@0Q730 zD%7YodS|cr2)HSS_fE?`Wl1!Au%wsGv`H?N+M-`oa^c-@vQ{Wb+MdywKXrzky}$7! zBcshY2Xw`9;Z`nBklq1T`mh=2P)TnUmm?*cN~PUDA+TL<%jPIR{}>gw#KQol7^Y9jj$^_JoI;}P`~?PO>w8s_^*qc~a& zZ8?`Wmo-9tnxlJCW2|gqJ8#{x(J5s^DTTH)hNaWQ12U?D?cdIH>0wpi2^HD&E(^!e zjzIcve}Wobrgpm4ckTE_W@*c18~emga3pl7@A92fIrqa)#y{E7h#gMhcDmt%7o}>- zNTG8uP-2mF-1R;0%#!#BrAHUZeDlS&$O4_0*H}V z0%^FjFWud*j7Y4Z!$5=0?D~mr4hHD?+>;Vyd%vBr^>G3Xo2{P*;yUs29FW5MhEHA% zwe`BDu$1LjIcs$scRufEg;r7xRo)J%{>Mka5`~=~9Jxz37 zdtd9>GkxEQP!iD@YGWlhBL)W3$aN&NQ$ybV_+=_?YshlOCCJxkr0P^Bpc6^1imUBg5l4 z4+>GsAs9x28=)w5BLR`z!9etX)M?{7(KFHiy?)?3!4f^_BAU^zZPX8Eu7G6L{{}qUz{$IjMI0@##$KXx$_3y#?&^N_JD_|`R2tK_!J!A^J-)*sRm1l+Gvq(U8WkRlURnxlJR8jbX* zzVTGtBJ1;|&yCXWy3+5yr{vVA*Z}F;C_`EB$Glc7K2BIi3iXB*PX^krS^jAW{m&)T zoyZoRfH!*osUEt*qg`vw$Zvv8`rqhe`~Yl3LB*1PYr;K_SK8G3-fy;UW*$DWaNIK_ zo19!}I4yuaAXpd~@4QN;6OKJo+lc)0D}&OK37}XhWP;yu0D8~1jjt|5?)a7kai0}U z1pUbGe6+wkkMZby+L2|f^XUfDjTOjl>243p80EA=Z3m!D2qru$W=Qg6*JT!y18RSTIr|aqg>6S9LsAtOzMc zRnc+R=f3;yJMY$g_uW)y=KAV7eWf`oaC}OLU+-Pref6EsixcOCi2A;3hJ8ozvm=fj zEf-F&o7M&2e{Om+@Tb4rbbQNnO0B+=D{NidZf(g>nt|1Ct#4nuxwEv`3Z^65Y=wc< z>IKSf_DgN6i;97Yfrn+_r1;X=*%{Vydg2s)_A}QWR$KT@R;~6?0qb?Ioq+HNUyWs52V*q>90HXJiV-bDh4VBDh4VBDh4VBDh4VB zDh4VBDh4VBDh4VB9zh07TZkuM*I#*~kN^MA_W%FH&9{L!fjB;8~yzoCN-GQivCU-vcJl1{Q!*z`s5r#GipTfY*Va0N((b zfB~EVJ`Mcs<3jubh=6V23UCQn1kM9(U;&s1o&*H&n~w?M0TTcLcRwn`uYdtSz#AWd zKHx0y-iL+wEpQ7s56lDaJ}$&xfER%u0eiqSaQ8#t2lyp01U7*+U;%gvcoKN~gLn?u zz#eb|cnY|Sg34>aE5IFK7kCzU1~>uSLqX;)@JHZhz;l2C`alop0?R-h;JJG%E7Xj9 ztkohj=%^?#gCW_1MgLD zAJcPMb#2)t3}R(jo= zXS$Ng%!)|j)y7LcKav>r*%;j%v$^m@Peeo1TcS?SAQGAGuJZQMIFf|{qSX*nlf;B% zj--7aEsi+&)d}PPQ3N3ZQF33BHMAVWvOq?0;3W!*UGi6UI;JIB7XwN8EyS*l3J?Pa zKIW@|*);QuG0Pp*(aS_?+a*L!8PLQM0?sgu!wK?nwRw4QXJV{j$wP>;hEGQlVZJZz zz)ae9^0>zujyQJuDkOO;Hl2_RxRY`bZ${w}W8xscr3PUZg(hj5X2;Ab`XM@7HxKUD zfv*-h+$emeNL^49gLYmfIv^eKHpVj0JlfXVAW4#|2qt^6+W}>;k7Hj#d5CqDcoE9& zA2Tyu0x~)lDTx{wI2M6svx&9Ao`j>7yexkal4f}$bnQs5hwj|jIvUqqo6Q+p^~m3c znL0k_Y7v?DXv-Q5=H}KA0JYvla&Ri@=p0HAv>mxGomLnlO6pV26+(Vyj=GS%g6s8D z_hKeoKt_9DC|dP8-RJ}gL35SYz^-XYP9SItl-m*u9|yW1=Jgu0TnNaZ(Cpzw<5c4j z%W9GwBKezT6igv=j6k-l5J(8i`W{U!AUophCOL`%SkytdyL*=E#0~Z`F-`!=g6)!E zqPbSnt@wgYEszHd*=oe}8ZmOe9>PZ=wLmlHlCya&dXA>1xYMjj)*n)h)MRaQHbj2B zt5(>}rxG=F^7xVmk+m%h;@B39HPAgJHtSOu&A-}eK4G(AU}9bmH!{`_C3Q?WPtMVN zj24G%aE9jh9Cymyp*_p2n^-^(3~d``F_odjSjgx_*0ikHjd}SoY~?mPvZs*5A~xd! z*%F`P)z_jBixg5>8IH0Dopa`6MI-H9JkETICw1el`hV&J~_$H z*yD~(4opDjc$t$WyTDBDa!f?YK(;E)pwatTBvR-Z$B0B^KTK?cawWBd#>_*LK%rYC z9is)brU4{Nx;n$g$TUd6*v5Uy!-b*oz$rmK&O$O!u)56~ozu#*oYT(OsZomLj_ST2%*w&4B$PqD|{!`_$o z|5H2!94}!%e+%$|E5N_8e}4t=fiD2BVbA_O;4*Lmcolo|9|O+=b>MTrKd|?H9k>I8 zzy@Fd_pk?l3-~#(1>D0P{5{}5z~6zFfjhtp!1F*GFn}|_yRh?3;QPQ$fbG5kR7b@? z#X!YC#X!YC#lR!NK(=AvecVzH*%_a#`6!_A$wjVg=prX4f29&|xhC2`YNyK+I7!PL z(7?s8tk%+?87?7>O2(XLaYb0#^JC>719C{4xP*)y{`+k+d0)#r(6kJkCa7g2bJOz; zc|&JnNO#BWFI7qIvt18`?dLqMSj)mKoy;fM+h`Y%)(|;iFFh|8x)N(ysMT6V)xAIu zGVUysE-84xi##iDqfp8JP=m}S6y1$k_5Mo7^dwr?KaQg$FR+dbVPDfK4zHU&P{p+D zUj=RV7guwqEJ*1=7b>c}JxP4{NTIT$6hz_dR6dF&uHtCnex*|D0~J}1EU)s*bm$(b t=x<{?qGVTAxH)E4;fb!MQoAarqJmtS|EiqIVVHwrRZiui?f;^j`d^8zT@L^N diff --git a/src/tdautils/dionysus2Utils.h b/src/tdautils/dionysus2Utils.h index 2af0585..641d50a 100644 --- a/src/tdautils/dionysus2Utils.h +++ b/src/tdautils/dionysus2Utils.h @@ -36,6 +36,134 @@ namespace d = dionysus; * write nothing after. */ + +/** + * Class: EvaluatePushBack + * + * Push back the simplex and the evaluated value + */ +template< typename Container, typename Evaluator > +class EvaluatePushBack2 { + +public: + EvaluatePushBack2(Container & argContainer, const Evaluator & argEvaluator) : + container(argContainer), evaluator(argEvaluator) {} + + void operator()(const typename Container::value_type & argSmp) const { + typename Container::value_type smp(argSmp.dimension(),argSmp.begin(),argSmp.end(), evaluator(argSmp)); + container.push_back(smp); + } + +private: + Container & container; + const Evaluator & evaluator; +}; + +template< typename VertexList, typename Evaluator > +unsigned getLocation(const VertexList & vertices, const Evaluator & evaluator) { + typename VertexList::const_iterator vertexItr; + unsigned vertex = *(vertices.begin()); + for (vertexItr = vertices.begin(); vertexItr != vertices.end(); ++vertexItr) { + if (evaluator[*vertexItr] > evaluator[vertex]) { + vertex = *vertexItr; + } + } + return vertex + 1; +} + +template< typename Simplex, typename Locations, typename Cycles, + typename Persistence, typename Evaluator, typename SimplexMap, + typename Filtration > +inline void initLocations( + Locations & locations, Cycles & cycles, const Persistence & p, + const Evaluator & evaluator, const SimplexMap & m, + const unsigned maxdimension, const Filtration & filtration) { + + unsigned verticesMax = 0; + for (typename Filtration::OrderConstIterator iFltr = filtration.begin(); + iFltr != filtration.end(); ++iFltr) { + const typename Filtration::Simplex & c = *(iFltr); + if (c.dimension() == 0) { + verticesMax = std::max(verticesMax, *(c.begin())); + } + } + + // vertices range from 0 to verticesMax + std::vector< double > verticesValues( + verticesMax + 1, -std::numeric_limits< double >::infinity()); + + for (typename Filtration::OrderConstIterator iFltr = filtration.begin(); + iFltr != filtration.end(); ++iFltr) { + const typename Filtration::Cell & c = *(iFltr); + if(c.dimension() == 0) { + verticesValues[*(c.begin())] = c.data(); + } + } + + locations.resize(maxdimension + 1); + cycles.resize(maxdimension + 1); + typename Locations::value_type::value_type persLocPoint(2); + typename Cycles::value_type::value_type persBdy; + typename Cycles::value_type::value_type::value_type persSimplex; + for (typename Persistence::iterator cur = p.begin(); cur != p.end(); ++cur) { + // positive simplices corresponds to + // negative simplices having non-empty cycles + if (cur->sign()) { + // first consider that cycle is paired + if (!cur->unpaired()) { + // the cycle that was born at cur is killed + // when we added death (another simplex) + const typename SimplexMap::key_type& death = cur->pair; + + //const typename SimplexMap::value_type& b = m[cur]; + //const typename SimplexMap::value_type& d = m[death]; + const typename Filtration::Cell & b = m[cur]; + const typename Filtration::Cell & d = m[death]; + if ((unsigned)b.dimension() > maxdimension) { + continue; + } + if (evaluator(b) < evaluator(d)) { + persLocPoint[0] = getLocation(b.vertices(), verticesValues); + persLocPoint[1] = getLocation(d.vertices(), verticesValues); + locations[b.dimension()].push_back(persLocPoint); + + // Iterate over the cycle + persBdy.clear(); + const typename Persistence::Cycle& cycle = death->cycle; + for (typename Persistence::Cycle::const_iterator + si = cycle.begin(); si != cycle.end(); ++si) { + persSimplex.clear(); + const typename Simplex::VertexContainer& + vertices = m[*si].vertices(); // std::vector where Vertex = Distances::IndexType + typename Simplex::VertexContainer::const_iterator vtxItr; + for (vtxItr = vertices.begin(); vtxItr != vertices.end(); + ++vtxItr) { + persSimplex.push_back(*vtxItr + 1); + } + persBdy.push_back(persSimplex); + } + cycles[b.dimension()].push_back(persBdy); + } + } + else { // cycles can be unpaired + const typename SimplexMap::value_type& b = m[cur]; + if ((unsigned)b.dimension() > maxdimension) { + continue; + } + persLocPoint[0] = getLocation(b.vertices(), verticesValues); + persLocPoint[1] = (unsigned)( + std::max_element(verticesValues.begin(), verticesValues.end()) + - verticesValues.begin() + 1); + locations[b.dimension()].push_back(persLocPoint); + + // Iterate over the cycle + persBdy.clear(); + cycles[b.dimension()].push_back(persBdy); + } + } + } +} + template void FiltrationDiagDionysus2( const Filtration &filtration, @@ -75,7 +203,8 @@ void FiltrationDiagDionysus2( } else { pt_ = {filtration[pt.birth()].data(),filtration[pt.death()].data()}; } - persDgm[_].push_back(pt_); + if (pt_[0] != pt_[1]) + persDgm[_].push_back(pt_); } _++; } @@ -83,30 +212,12 @@ void FiltrationDiagDionysus2( if (persDgm.size() > maxdimension) { persDgm.resize(maxdimension + 1); } + if (location) { + initLocations< typename Filtration::Cell >( + persLoc, persCycle, p, typename Filtration::Cell::DataEvaluator(), + m, maxdimension, filtration); + } } - -/** - * Class: EvaluatePushBack - * - * Push back the simplex and the evaluated value - */ -template< typename Container, typename Evaluator > -class EvaluatePushBack2 { - -public: - EvaluatePushBack2(Container & argContainer, const Evaluator & argEvaluator) : - container(argContainer), evaluator(argEvaluator) {} - - void operator()(const typename Container::value_type & argSmp) const { - typename Container::value_type smp(argSmp.dimension(),argSmp.begin(),argSmp.end(), evaluator(argSmp)); - container.push_back(smp); - } - -private: - Container & container; - const Evaluator & evaluator; -}; - template< typename Distances, typename Generator, typename Filtration, typename RealMatrix, typename Print > inline Filtration RipsFiltrationDionysus2( diff --git a/src/tdautils/dionysusUtils.h b/src/tdautils/dionysusUtils.h index 2ae8d58..58e7444 100644 --- a/src/tdautils/dionysusUtils.h +++ b/src/tdautils/dionysusUtils.h @@ -391,39 +391,5 @@ inline Filtration RipsFiltrationDionysus( return filtration; } -/* -template< typename Distances, typename Generator, typename Filtration, - typename RealMatrix, typename Print > -inline Filtration RipsFiltrationDionysus2( - const RealMatrix & X, - const unsigned nSample, - const unsigned nDim, - const bool is_row_names, - const int maxdimension, - const double maxscale, - const bool printProgress, - const Print & print -) { - // This is a Matrix of Points - PointContainer points = TdaToStl< PointContainer >(X, nSample, nDim, is_row_names); - // - Distances distances(points); - Generator rips(distances); - typename Generator::Evaluator size(distances); - Filtration filtration; - EvaluatePushBack< Filtration, typename Generator::Evaluator > functor( - filtration, size); - // Generate maxdimension skeleton of the Rips complex - rips.generate(maxdimension + 1, maxscale, functor); - if (printProgress) { - print("# Generated complex of size: %d \n", filtration.size()); - } - // Sort the simplices with respect to comparison criteria - // e.g. distance or function values - filtration.sort(ComparisonDataDimension< typename Filtration::Simplex >()); - return filtration; -} -*/ - -# endif // __DIONYSUSUTILS_H__ \ No newline at end of file +# endif // __DIONYSUSUTILS_H__