diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index 5bf9a2bfc8..a9a893f41a 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -60,26 +60,34 @@ class Object : NonCopyable { public: virtual ~Object() {} Handle GetHandle() const { return handle; } + virtual std::string GetTypeName() const { return "[BAD KERNEL OBJECT TYPE]"; } virtual std::string GetName() const { return "[UNKNOWN KERNEL OBJECT]"; } virtual Kernel::HandleType GetHandleType() const = 0; /** - * Check if this object is available - * @return True if the current thread should wait due to this object being unavailable + * Check if a thread can wait on the object + * @return True if a thread can wait on the object, otherwise false */ - virtual ResultVal Wait() { - LOG_ERROR(Kernel, "(UNIMPLEMENTED)"); - return UnimplementedFunction(ErrorModule::Kernel); - } + bool IsWaitable() const { + switch (GetHandleType()) { + case HandleType::Event: + case HandleType::Mutex: + case HandleType::Thread: + case HandleType::Semaphore: + case HandleType::Timer: + return true; - /** - * Acquire/lock the this object if it is available - * @return True if we were able to acquire this object, otherwise false - */ - virtual ResultVal Acquire() { - LOG_ERROR(Kernel, "(UNIMPLEMENTED)"); - return UnimplementedFunction(ErrorModule::Kernel); + case HandleType::Unknown: + case HandleType::Port: + case HandleType::SharedMemory: + case HandleType::Redirection: + case HandleType::Process: + case HandleType::AddressArbiter: + return false; + } + + return false; } private: @@ -107,6 +115,24 @@ using SharedPtr = boost::intrusive_ptr; class WaitObject : public Object { public: + /** + * Check if this object is available + * @return True if the current thread should wait due to this object being unavailable + */ + virtual ResultVal Wait() { + LOG_ERROR(Kernel, "(UNIMPLEMENTED)"); + return UnimplementedFunction(ErrorModule::Kernel); + } + + /** + * Acquire/lock the this object if it is available + * @return True if we were able to acquire this object, otherwise false + */ + virtual ResultVal Acquire() { + LOG_ERROR(Kernel, "(UNIMPLEMENTED)"); + return UnimplementedFunction(ErrorModule::Kernel); + } + /** * Add a thread to wait on this object * @param thread Pointer to thread to add @@ -186,14 +212,14 @@ public: /** * Looks up a handle. - * @returns Pointer to the looked-up object, or `nullptr` if the handle is not valid. + * @return Pointer to the looked-up object, or `nullptr` if the handle is not valid. */ SharedPtr GetGeneric(Handle handle) const; /** * Looks up a handle while verifying its type. - * @returns Pointer to the looked-up object, or `nullptr` if the handle is not valid or its - * type differs from the handle type `T::HANDLE_TYPE`. + * @return Pointer to the looked-up object, or `nullptr` if the handle is not valid or its + * type differs from the handle type `T::HANDLE_TYPE`. */ template SharedPtr Get(Handle handle) const { @@ -204,6 +230,19 @@ public: return nullptr; } + /** + * Looks up a handle while verifying that it is an object that a thread can wait on + * @return Pointer to the looked-up object, or `nullptr` if the handle is not valid or it is + * not a waitable object. + */ + SharedPtr GetWaitObject(Handle handle) const { + SharedPtr object = GetGeneric(handle); + if (object != nullptr && object->IsWaitable()) { + return boost::static_pointer_cast(std::move(object)); + } + return nullptr; + } + /// Closes all handles held in this table. void Clear(); diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 16865ccc40..271828ea7a 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -210,7 +210,7 @@ void WaitCurrentThread_Sleep() { ChangeThreadState(thread, ThreadStatus(THREADSTATUS_WAIT | (thread->status & THREADSTATUS_SUSPEND))); } -void WaitCurrentThread_WaitSynchronization(WaitObject* wait_object, bool wait_all) { +void WaitCurrentThread_WaitSynchronization(SharedPtr wait_object, bool wait_all) { Thread* thread = GetCurrentThread(); thread->wait_all = wait_all; thread->wait_address = 0; diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index 9907aa6e15..a3a17e6c06 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -136,7 +136,7 @@ void WaitCurrentThread_Sleep(); * @param wait_object Kernel object that we are waiting on * @param wait_all If true, wait on all objects before resuming (for WaitSynchronizationN only) */ -void WaitCurrentThread_WaitSynchronization(WaitObject* wait_object, bool wait_all=false); +void WaitCurrentThread_WaitSynchronization(SharedPtr wait_object, bool wait_all = false); /** * Waits the current thread from an ArbitrateAddress call diff --git a/src/core/hle/svc.cpp b/src/core/hle/svc.cpp index 5e9c38973c..8df8616694 100644 --- a/src/core/hle/svc.cpp +++ b/src/core/hle/svc.cpp @@ -115,7 +115,7 @@ static Result CloseHandle(Handle handle) { /// Wait for a handle to synchronize, timeout after the specified nanoseconds static Result WaitSynchronization1(Handle handle, s64 nano_seconds) { - Kernel::WaitObject* object = static_cast(Kernel::g_handle_table.GetGeneric(handle).get()); + auto object = Kernel::g_handle_table.GetWaitObject(handle); if (object == nullptr) return InvalidHandle(ErrorModule::Kernel).raw; @@ -163,7 +163,7 @@ static Result WaitSynchronizationN(s32* out, Handle* handles, s32 handle_count, if (handle_count != 0) { bool selected = false; // True once an object has been selected for (int i = 0; i < handle_count; ++i) { - Kernel::WaitObject* object = static_cast(Kernel::g_handle_table.GetGeneric(handles[i]).get()); + auto object = Kernel::g_handle_table.GetWaitObject(handles[i]); if (object == nullptr) return InvalidHandle(ErrorModule::Kernel).raw;