diff --git a/include/boost/thread/thread.hpp b/include/boost/thread/thread.hpp index dfcf2e23..c02c6db5 100644 --- a/include/boost/thread/thread.hpp +++ b/include/boost/thread/thread.hpp @@ -38,6 +38,16 @@ public: ~thread_cancel(); }; +class BOOST_THREAD_DECL cancellation_guard +{ +public: + cancellation_guard(); + ~cancellation_guard(); + +private: + void* m_handle; +}; + #if defined(BOOST_HAS_WINTHREADS) struct sched_param @@ -127,11 +137,6 @@ public: static const int stack_min; - // This is an implementation detail and should be private, - // but we need it to be public to access the type in some - // unnamed namespace free functions in the implementation. - class data; - private: template friend std::basic_ostream& operator<<(std::basic_ostream&, const thread&); @@ -142,7 +147,7 @@ private: const void* id() const; #endif - data* m_handle; + void* m_handle; }; template diff --git a/src/thread.cpp b/src/thread.cpp index d5ef0b4b..f8f4bff8 100644 --- a/src/thread.cpp +++ b/src/thread.cpp @@ -34,7 +34,7 @@ namespace boost { -class thread::data +class thread_data { public: enum @@ -45,14 +45,16 @@ public: joined }; - data(const boost::function0& threadfunc); - data(); - ~data(); + thread_data(const boost::function0& threadfunc); + thread_data(); + ~thread_data(); void addref(); bool release(); void join(); void cancel(); + void enable_cancellation(); + void disable_cancellation(); void test_cancel(); void run(); #if defined(BOOST_HAS_WINTHREADS) @@ -64,6 +66,8 @@ public: void set_scheduling_parameter(int policy, const sched_param& param); void get_scheduling_parameter(int& policy, sched_param& param) const; + static thread_data* get_current(); + private: mutable boost::mutex m_mutex; mutable boost::condition m_cond; @@ -80,6 +84,7 @@ private: MPTaskID m_pTaskID; #endif bool m_canceled; + int m_cancellation_disabled_level; bool m_native; }; @@ -105,14 +110,14 @@ struct as_pointer : private boost::mpl::if_, pointer_based, static const void* from(const T& obj) { return do_from(obj); } }; -void release_tss_data(boost::thread::data* data) +void release_tss_data(boost::thread_data* data) { assert(data); if (data->release()) delete data; } -boost::thread_specific_ptr tss_thread_data(&release_tss_data); +boost::thread_specific_ptr tss_thread_data(&release_tss_data); struct thread_equals { @@ -135,7 +140,7 @@ static OSStatus thread_proxy(void* param) { try { - boost::thread::data* tdata = static_cast(param); + boost::thread_data* tdata = static_cast(param); tss_thread_data.reset(tdata); tdata->run(); } @@ -156,13 +161,15 @@ static OSStatus thread_proxy(void* param) namespace boost { -thread::data::data(const boost::function0& threadfunc) - : m_threadfunc(threadfunc), m_refcount(2), m_state(creating), m_canceled(false), m_native(false) +thread_data::thread_data(const boost::function0& threadfunc) + : m_threadfunc(threadfunc), m_refcount(2), m_state(creating), m_canceled(false), m_native(false), + m_cancellation_disabled_level(0) { } -thread::data::data() - : m_refcount(2), m_state(running), m_canceled(false), m_native(true) +thread_data::thread_data() + : m_refcount(1), m_state(running), m_canceled(false), m_native(true), + m_cancellation_disabled_level(0) { #if defined(BOOST_HAS_WINTHREADS) DuplicateHandle(GetCurrentProcess(), GetCurrentThread(), GetCurrentProcess(), @@ -173,7 +180,7 @@ thread::data::data() #endif } -thread::data::~data() +thread_data::~thread_data() { if (m_state != joined) { @@ -194,7 +201,7 @@ thread::data::~data() } } -void thread::data::addref() +void thread_data::addref() { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -202,7 +209,7 @@ void thread::data::addref() ++m_refcount; } -bool thread::data::release() +bool thread_data::release() { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -210,7 +217,7 @@ bool thread::data::release() return (--m_refcount == 0); } -void thread::data::join() +void thread_data::join() { { boost::mutex::scoped_lock lock(m_mutex); @@ -251,7 +258,7 @@ void thread::data::join() m_cond.notify_all(); } -void thread::data::cancel() +void thread_data::cancel() { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -259,16 +266,28 @@ void thread::data::cancel() m_canceled = true; } -void thread::data::test_cancel() +void thread_data::test_cancel() { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) m_cond.wait(lock); - if (m_canceled) + if (m_cancellation_disabled_level == 0 && m_canceled) throw boost::thread_cancel(); } -void thread::data::run() +void thread_data::disable_cancellation() +{ + boost::mutex::scoped_lock lock(m_mutex); + m_cancellation_disabled_level++; +} + +void thread_data::enable_cancellation() +{ + boost::mutex::scoped_lock lock(m_mutex); + m_cancellation_disabled_level--; +} + +void thread_data::run() { { boost::mutex::scoped_lock lock(m_mutex); @@ -279,14 +298,14 @@ void thread::data::run() #elif defined(BOOST_HAS_PTHREADS) m_thread = pthread_self(); #endif - m_state = thread::data::running; + m_state = thread_data::running; m_cond.notify_all(); } m_threadfunc(); } #if defined(BOOST_HAS_WINTHREADS) -long thread::data::id() const +long thread_data::id() const { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -298,7 +317,7 @@ long thread::data::id() const return 0; // throw instead? } #else -const void* thread::data::id() const +const void* thread_data::id() const { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -316,7 +335,7 @@ const void* thread::data::id() const } #endif -void thread::data::set_scheduling_parameter(int policy, const sched_param& param) +void thread_data::set_scheduling_parameter(int policy, const sched_param& param) { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -349,7 +368,7 @@ void thread::data::set_scheduling_parameter(int policy, const sched_param& param #endif } -void thread::data::get_scheduling_parameter(int& policy, sched_param& param) const +void thread_data::get_scheduling_parameter(int& policy, sched_param& param) const { boost::mutex::scoped_lock lock(m_mutex); while (m_state == creating) @@ -369,6 +388,19 @@ void thread::data::get_scheduling_parameter(int& policy, sched_param& param) con #endif } +thread_data* thread_data::get_current() +{ + thread_data* data = tss_thread_data.get(); + if (data == 0) + { + data = new(std::nothrow) thread_data; + if (!data) + throw thread_resource_error(); + tss_thread_data.reset(data); + } + return data; +} + thread_cancel::thread_cancel() { } @@ -377,6 +409,18 @@ thread_cancel::~thread_cancel() { } +cancellation_guard::cancellation_guard() +{ + thread_data* data = thread_data::get_current(); + m_handle = data; + data->disable_cancellation(); +} + +cancellation_guard::~cancellation_guard() +{ + static_cast(m_handle)->enable_cancellation(); +} + thread::attributes::attributes() { #if defined(BOOST_HAS_WINTHREADS) @@ -602,23 +646,15 @@ thread::thread() m_pTaskID = MPCurrentTaskID(); m_pJoinQueueID = kInvalidID; #endif - thread::data* tdata = tss_thread_data.get(); - if (tdata == 0) - { - tdata = new(std::nothrow) thread::data; - if (!tdata) - throw thread_resource_error(); - tss_thread_data.reset(tdata); - } - else - tdata->addref(); + thread_data* tdata = thread_data::get_current(); + tdata->addref(); m_handle = tdata; } thread::thread(const function0& threadfunc, attributes attr) : m_handle(0) { - std::auto_ptr param(new(std::nothrow) thread::data(threadfunc)); + std::auto_ptr param(new(std::nothrow) thread_data(threadfunc)); if (param.get() == 0) throw thread_resource_error(); #if defined(BOOST_HAS_WINTHREADS) @@ -662,21 +698,22 @@ thread::thread(const function0& threadfunc, attributes attr) thread::thread(const thread& other) : m_handle(other.m_handle) { - m_handle->addref(); + static_cast(m_handle)->addref(); } thread::~thread() { - if (m_handle && m_handle->release()) + if (m_handle && static_cast(m_handle)->release()) delete m_handle; } thread& thread::operator=(const thread& other) { - if (m_handle->release()) - delete m_handle; + thread_data* data = static_cast(m_handle); + if (data->release()) + delete data; m_handle = other.m_handle; - m_handle->addref(); + static_cast(m_handle)->addref(); return *this; } @@ -687,38 +724,38 @@ bool thread::operator==(const thread& other) const bool thread::operator!=(const thread& other) const { - return !operator==(other); + return m_handle != other.m_handle; } bool thread::operator<(const thread& other) const { - return std::less()(m_handle, m_handle); + return std::less()(m_handle, other.m_handle); } void thread::join() { - m_handle->join(); + static_cast(m_handle)->join(); } void thread::cancel() { - m_handle->cancel(); + static_cast(m_handle)->cancel(); } void thread::test_cancel() { thread self; - self.m_handle->test_cancel(); + static_cast(self.m_handle)->test_cancel(); } void thread::set_scheduling_parameter(int policy, const sched_param& param) { - m_handle->set_scheduling_parameter(policy, param); + static_cast(m_handle)->set_scheduling_parameter(policy, param); } void thread::get_scheduling_parameter(int& policy, sched_param& param) const { - m_handle->get_scheduling_parameter(policy, param); + static_cast(m_handle)->get_scheduling_parameter(policy, param); } int thread::max_priority(int policy) @@ -822,7 +859,7 @@ const void* thread::id() const #endif { std::cout << *this; - return m_handle->id(); + return static_cast(m_handle)->id(); } #if defined(BOOST_HAS_WINTHREADS) diff --git a/test/test_thread.cpp b/test/test_thread.cpp index 387245b2..b5da0c7a 100644 --- a/test/test_thread.cpp +++ b/test/test_thread.cpp @@ -29,6 +29,38 @@ namespace test_value = 999; } + void cancel_thread() + { + // Sleep long enough to let the main thread cancel us + boost::xtime xt; + BOOST_CHECK_EQUAL(boost::xtime_get(&xt, boost::TIME_UTC), + static_cast(boost::TIME_UTC)); + xt.sec += 3; + boost::thread::sleep(xt); + + // This block will test the cancellation guard. If it + // doesn't work, we'll be cancelled with out setting + // the test_value to 999. + { + boost::cancellation_guard guard; + boost::thread::test_cancel(); + } + + // This block tests the cancellation itself. If it + // works a thread_cancel exception will be thrown, + // and in the catch handler for it we'll set our + // exptected test_value of 999. + try + { + boost::thread::test_cancel(); + } + catch (boost::thread_cancel& cancel) + { + test_value = 999; + throw; // Make sure to re-throw! + } + } + struct thread_adapter { thread_adapter(void (*func)(boost::thread& parent), @@ -78,10 +110,25 @@ void test_creation() void test_comparison() { boost::thread self; + BOOST_CHECK_EQUAL(self, boost::thread()); + boost::thread thrd(thread_adapter(comparison_thread, self)); + boost::thread thrd2 = thrd; + + BOOST_CHECK(thrd != self); + BOOST_CHECK(thrd == thrd2); + thrd.join(); } +void test_cancel() +{ + test_value = 0; + boost::thread thrd(&cancel_thread); + thrd.cancel(); + BOOST_CHECK_EQUAL(test_value, 999); // only true if thread was cancelled +} + boost::unit_test_framework::test_suite* init_unit_test_suite(int, char*[]) { boost::unit_test_framework::test_suite* test =