diff --git a/src/core/core.cpp b/src/core/core.cpp index ab7181a05e..6dda20faa7 100644 --- a/src/core/core.cpp +++ b/src/core/core.cpp @@ -78,6 +78,7 @@ FileSys::VirtualFile GetGameFileFromPath(const FileSys::VirtualFilesystem& vfs, return vfs->OpenFile(path, FileSys::Mode::Read); } struct System::Impl { + explicit Impl(System& system) : kernel{system} {} Cpu& CurrentCpuCore() { return cpu_core_manager.GetCurrentCore(); @@ -95,7 +96,7 @@ struct System::Impl { LOG_DEBUG(HW_Memory, "initialized OK"); core_timing.Initialize(); - kernel.Initialize(core_timing); + kernel.Initialize(); const auto current_time = std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()); @@ -265,7 +266,7 @@ struct System::Impl { Core::FrameLimiter frame_limiter; }; -System::System() : impl{std::make_unique()} {} +System::System() : impl{std::make_unique(*this)} {} System::~System() = default; Cpu& System::CurrentCpuCore() { diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index a250d088dc..9780a78491 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -9,6 +9,7 @@ #include "common/common_types.h" #include "core/core.h" #include "core/core_cpu.h" +#include "core/hle/kernel/address_arbiter.h" #include "core/hle/kernel/errors.h" #include "core/hle/kernel/object.h" #include "core/hle/kernel/process.h" @@ -17,32 +18,144 @@ #include "core/hle/result.h" #include "core/memory.h" -namespace Kernel::AddressArbiter { +namespace Kernel { +namespace { +// Wake up num_to_wake (or all) threads in a vector. +void WakeThreads(const std::vector>& waiting_threads, s32 num_to_wake) { + // Only process up to 'target' threads, unless 'target' is <= 0, in which case process + // them all. + std::size_t last = waiting_threads.size(); + if (num_to_wake > 0) { + last = num_to_wake; + } -// Performs actual address waiting logic. -static ResultCode WaitForAddress(VAddr address, s64 timeout) { - SharedPtr current_thread = GetCurrentThread(); + // Signal the waiting threads. + for (std::size_t i = 0; i < last; i++) { + ASSERT(waiting_threads[i]->GetStatus() == ThreadStatus::WaitArb); + waiting_threads[i]->SetWaitSynchronizationResult(RESULT_SUCCESS); + waiting_threads[i]->SetArbiterWaitAddress(0); + waiting_threads[i]->ResumeFromWait(); + } +} +} // Anonymous namespace + +AddressArbiter::AddressArbiter(Core::System& system) : system{system} {} +AddressArbiter::~AddressArbiter() = default; + +ResultCode AddressArbiter::SignalToAddress(VAddr address, s32 num_to_wake) { + const std::vector> waiting_threads = GetThreadsWaitingOnAddress(address); + WakeThreads(waiting_threads, num_to_wake); + return RESULT_SUCCESS; +} + +ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, + s32 num_to_wake) { + // Ensure that we can write to the address. + if (!Memory::IsValidVirtualAddress(address)) { + return ERR_INVALID_ADDRESS_STATE; + } + + if (static_cast(Memory::Read32(address)) != value) { + return ERR_INVALID_STATE; + } + + Memory::Write32(address, static_cast(value + 1)); + return SignalToAddress(address, num_to_wake); +} + +ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, + s32 num_to_wake) { + // Ensure that we can write to the address. + if (!Memory::IsValidVirtualAddress(address)) { + return ERR_INVALID_ADDRESS_STATE; + } + + // Get threads waiting on the address. + const std::vector> waiting_threads = GetThreadsWaitingOnAddress(address); + + // Determine the modified value depending on the waiting count. + s32 updated_value; + if (waiting_threads.empty()) { + updated_value = value - 1; + } else if (num_to_wake <= 0 || waiting_threads.size() <= static_cast(num_to_wake)) { + updated_value = value + 1; + } else { + updated_value = value; + } + + if (static_cast(Memory::Read32(address)) != value) { + return ERR_INVALID_STATE; + } + + Memory::Write32(address, static_cast(updated_value)); + WakeThreads(waiting_threads, num_to_wake); + return RESULT_SUCCESS; +} + +ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, + bool should_decrement) { + // Ensure that we can read the address. + if (!Memory::IsValidVirtualAddress(address)) { + return ERR_INVALID_ADDRESS_STATE; + } + + const s32 cur_value = static_cast(Memory::Read32(address)); + if (cur_value >= value) { + return ERR_INVALID_STATE; + } + + if (should_decrement) { + Memory::Write32(address, static_cast(cur_value - 1)); + } + + // Short-circuit without rescheduling, if timeout is zero. + if (timeout == 0) { + return RESULT_TIMEOUT; + } + + return WaitForAddress(address, timeout); +} + +ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) { + // Ensure that we can read the address. + if (!Memory::IsValidVirtualAddress(address)) { + return ERR_INVALID_ADDRESS_STATE; + } + // Only wait for the address if equal. + if (static_cast(Memory::Read32(address)) != value) { + return ERR_INVALID_STATE; + } + // Short-circuit without rescheduling, if timeout is zero. + if (timeout == 0) { + return RESULT_TIMEOUT; + } + + return WaitForAddress(address, timeout); +} + +ResultCode AddressArbiter::WaitForAddress(VAddr address, s64 timeout) { + SharedPtr current_thread = system.CurrentScheduler().GetCurrentThread(); current_thread->SetArbiterWaitAddress(address); current_thread->SetStatus(ThreadStatus::WaitArb); current_thread->InvalidateWakeupCallback(); current_thread->WakeAfterDelay(timeout); - Core::System::GetInstance().CpuCore(current_thread->GetProcessorID()).PrepareReschedule(); + system.CpuCore(current_thread->GetProcessorID()).PrepareReschedule(); return RESULT_TIMEOUT; } -// Gets the threads waiting on an address. -static std::vector> GetThreadsWaitingOnAddress(VAddr address) { - const auto RetrieveWaitingThreads = [](std::size_t core_index, - std::vector>& waiting_threads, - VAddr arb_addr) { - const auto& scheduler = Core::System::GetInstance().Scheduler(core_index); +std::vector> AddressArbiter::GetThreadsWaitingOnAddress(VAddr address) const { + const auto RetrieveWaitingThreads = [this](std::size_t core_index, + std::vector>& waiting_threads, + VAddr arb_addr) { + const auto& scheduler = system.Scheduler(core_index); const auto& thread_list = scheduler.GetThreadList(); for (const auto& thread : thread_list) { - if (thread->GetArbiterWaitAddress() == arb_addr) + if (thread->GetArbiterWaitAddress() == arb_addr) { waiting_threads.push_back(thread); + } } }; @@ -61,118 +174,4 @@ static std::vector> GetThreadsWaitingOnAddress(VAddr address) return threads; } - -// Wake up num_to_wake (or all) threads in a vector. -static void WakeThreads(std::vector>& waiting_threads, s32 num_to_wake) { - // Only process up to 'target' threads, unless 'target' is <= 0, in which case process - // them all. - std::size_t last = waiting_threads.size(); - if (num_to_wake > 0) - last = num_to_wake; - - // Signal the waiting threads. - for (std::size_t i = 0; i < last; i++) { - ASSERT(waiting_threads[i]->GetStatus() == ThreadStatus::WaitArb); - waiting_threads[i]->SetWaitSynchronizationResult(RESULT_SUCCESS); - waiting_threads[i]->SetArbiterWaitAddress(0); - waiting_threads[i]->ResumeFromWait(); - } -} - -// Signals an address being waited on. -ResultCode SignalToAddress(VAddr address, s32 num_to_wake) { - std::vector> waiting_threads = GetThreadsWaitingOnAddress(address); - - WakeThreads(waiting_threads, num_to_wake); - return RESULT_SUCCESS; -} - -// Signals an address being waited on and increments its value if equal to the value argument. -ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake) { - // Ensure that we can write to the address. - if (!Memory::IsValidVirtualAddress(address)) { - return ERR_INVALID_ADDRESS_STATE; - } - - if (static_cast(Memory::Read32(address)) == value) { - Memory::Write32(address, static_cast(value + 1)); - } else { - return ERR_INVALID_STATE; - } - - return SignalToAddress(address, num_to_wake); -} - -// Signals an address being waited on and modifies its value based on waiting thread count if equal -// to the value argument. -ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, - s32 num_to_wake) { - // Ensure that we can write to the address. - if (!Memory::IsValidVirtualAddress(address)) { - return ERR_INVALID_ADDRESS_STATE; - } - - // Get threads waiting on the address. - std::vector> waiting_threads = GetThreadsWaitingOnAddress(address); - - // Determine the modified value depending on the waiting count. - s32 updated_value; - if (waiting_threads.empty()) { - updated_value = value - 1; - } else if (num_to_wake <= 0 || waiting_threads.size() <= static_cast(num_to_wake)) { - updated_value = value + 1; - } else { - updated_value = value; - } - - if (static_cast(Memory::Read32(address)) == value) { - Memory::Write32(address, static_cast(updated_value)); - } else { - return ERR_INVALID_STATE; - } - - WakeThreads(waiting_threads, num_to_wake); - return RESULT_SUCCESS; -} - -// Waits on an address if the value passed is less than the argument value, optionally decrementing. -ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool should_decrement) { - // Ensure that we can read the address. - if (!Memory::IsValidVirtualAddress(address)) { - return ERR_INVALID_ADDRESS_STATE; - } - - s32 cur_value = static_cast(Memory::Read32(address)); - if (cur_value < value) { - if (should_decrement) { - Memory::Write32(address, static_cast(cur_value - 1)); - } - } else { - return ERR_INVALID_STATE; - } - // Short-circuit without rescheduling, if timeout is zero. - if (timeout == 0) { - return RESULT_TIMEOUT; - } - - return WaitForAddress(address, timeout); -} - -// Waits on an address if the value passed is equal to the argument value. -ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) { - // Ensure that we can read the address. - if (!Memory::IsValidVirtualAddress(address)) { - return ERR_INVALID_ADDRESS_STATE; - } - // Only wait for the address if equal. - if (static_cast(Memory::Read32(address)) != value) { - return ERR_INVALID_STATE; - } - // Short-circuit without rescheduling, if timeout is zero. - if (timeout == 0) { - return RESULT_TIMEOUT; - } - - return WaitForAddress(address, timeout); -} -} // namespace Kernel::AddressArbiter +} // namespace Kernel diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index b58f21becb..e0c36f2e3e 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -5,28 +5,68 @@ #pragma once #include "common/common_types.h" +#include "core/hle/kernel/address_arbiter.h" union ResultCode; -namespace Kernel::AddressArbiter { +namespace Core { +class System; +} -enum class ArbitrationType { - WaitIfLessThan = 0, - DecrementAndWaitIfLessThan = 1, - WaitIfEqual = 2, +namespace Kernel { + +class Thread; + +class AddressArbiter { +public: + enum class ArbitrationType { + WaitIfLessThan = 0, + DecrementAndWaitIfLessThan = 1, + WaitIfEqual = 2, + }; + + enum class SignalType { + Signal = 0, + IncrementAndSignalIfEqual = 1, + ModifyByWaitingCountAndSignalIfEqual = 2, + }; + + explicit AddressArbiter(Core::System& system); + ~AddressArbiter(); + + AddressArbiter(const AddressArbiter&) = delete; + AddressArbiter& operator=(const AddressArbiter&) = delete; + + AddressArbiter(AddressArbiter&&) = default; + AddressArbiter& operator=(AddressArbiter&&) = delete; + + /// Signals an address being waited on. + ResultCode SignalToAddress(VAddr address, s32 num_to_wake); + + /// Signals an address being waited on and increments its value if equal to the value argument. + ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake); + + /// Signals an address being waited on and modifies its value based on waiting thread count if + /// equal to the value argument. + ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, + s32 num_to_wake); + + /// Waits on an address if the value passed is less than the argument value, + /// optionally decrementing. + ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, + bool should_decrement); + + /// Waits on an address if the value passed is equal to the argument value. + ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout); + +private: + // Waits on the given address with a timeout in nanoseconds + ResultCode WaitForAddress(VAddr address, s64 timeout); + + // Gets the threads waiting on an address. + std::vector> GetThreadsWaitingOnAddress(VAddr address) const; + + Core::System& system; }; -enum class SignalType { - Signal = 0, - IncrementAndSignalIfEqual = 1, - ModifyByWaitingCountAndSignalIfEqual = 2, -}; - -ResultCode SignalToAddress(VAddr address, s32 num_to_wake); -ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake); -ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake); - -ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool should_decrement); -ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout); - -} // namespace Kernel::AddressArbiter +} // namespace Kernel diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index dd749eed48..04ea9349ee 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -12,6 +12,7 @@ #include "core/core.h" #include "core/core_timing.h" +#include "core/hle/kernel/address_arbiter.h" #include "core/hle/kernel/client_port.h" #include "core/hle/kernel/handle_table.h" #include "core/hle/kernel/kernel.h" @@ -86,11 +87,13 @@ static void ThreadWakeupCallback(u64 thread_handle, [[maybe_unused]] int cycles_ } struct KernelCore::Impl { - void Initialize(KernelCore& kernel, Core::Timing::CoreTiming& core_timing) { + explicit Impl(Core::System& system) : address_arbiter{system}, system{system} {} + + void Initialize(KernelCore& kernel) { Shutdown(); InitializeSystemResourceLimit(kernel); - InitializeThreads(core_timing); + InitializeThreads(); } void Shutdown() { @@ -122,9 +125,9 @@ struct KernelCore::Impl { ASSERT(system_resource_limit->SetLimitValue(ResourceType::Sessions, 900).IsSuccess()); } - void InitializeThreads(Core::Timing::CoreTiming& core_timing) { + void InitializeThreads() { thread_wakeup_event_type = - core_timing.RegisterEvent("ThreadWakeupCallback", ThreadWakeupCallback); + system.CoreTiming().RegisterEvent("ThreadWakeupCallback", ThreadWakeupCallback); } std::atomic next_object_id{0}; @@ -135,6 +138,8 @@ struct KernelCore::Impl { std::vector> process_list; Process* current_process = nullptr; + Kernel::AddressArbiter address_arbiter; + SharedPtr system_resource_limit; Core::Timing::EventType* thread_wakeup_event_type = nullptr; @@ -145,15 +150,18 @@ struct KernelCore::Impl { /// Map of named ports managed by the kernel, which can be retrieved using /// the ConnectToPort SVC. NamedPortTable named_ports; + + // System context + Core::System& system; }; -KernelCore::KernelCore() : impl{std::make_unique()} {} +KernelCore::KernelCore(Core::System& system) : impl{std::make_unique(system)} {} KernelCore::~KernelCore() { Shutdown(); } -void KernelCore::Initialize(Core::Timing::CoreTiming& core_timing) { - impl->Initialize(*this, core_timing); +void KernelCore::Initialize() { + impl->Initialize(*this); } void KernelCore::Shutdown() { @@ -184,6 +192,14 @@ const Process* KernelCore::CurrentProcess() const { return impl->current_process; } +AddressArbiter& KernelCore::AddressArbiter() { + return impl->address_arbiter; +} + +const AddressArbiter& KernelCore::AddressArbiter() const { + return impl->address_arbiter; +} + void KernelCore::AddNamedPort(std::string name, SharedPtr port) { impl->named_ports.emplace(std::move(name), std::move(port)); } diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index 154bced42b..4d292aca9c 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -11,6 +11,10 @@ template class ResultVal; +namespace Core { +class System; +} + namespace Core::Timing { class CoreTiming; struct EventType; @@ -18,6 +22,7 @@ struct EventType; namespace Kernel { +class AddressArbiter; class ClientPort; class HandleTable; class Process; @@ -30,7 +35,14 @@ private: using NamedPortTable = std::unordered_map>; public: - KernelCore(); + /// Constructs an instance of the kernel using the given System + /// instance as a context for any necessary system-related state, + /// such as threads, CPU core state, etc. + /// + /// @post After execution of the constructor, the provided System + /// object *must* outlive the kernel instance itself. + /// + explicit KernelCore(Core::System& system); ~KernelCore(); KernelCore(const KernelCore&) = delete; @@ -40,11 +52,7 @@ public: KernelCore& operator=(KernelCore&&) = delete; /// Resets the kernel to a clean slate for use. - /// - /// @param core_timing CoreTiming instance used to create any necessary - /// kernel-specific callback events. - /// - void Initialize(Core::Timing::CoreTiming& core_timing); + void Initialize(); /// Clears all resources in use by the kernel instance. void Shutdown(); @@ -67,6 +75,12 @@ public: /// Retrieves a const pointer to the current process. const Process* CurrentProcess() const; + /// Provides a reference to the kernel's address arbiter. + Kernel::AddressArbiter& AddressArbiter(); + + /// Provides a const reference to the kernel's address arbiter. + const Kernel::AddressArbiter& AddressArbiter() const; + /// Adds a port to the named port table void AddNamedPort(std::string name, SharedPtr port); diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index 223d717e27..75b88a3339 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -1478,13 +1478,14 @@ static ResultCode WaitForAddress(VAddr address, u32 type, s32 value, s64 timeout return ERR_INVALID_ADDRESS; } + auto& address_arbiter = Core::System::GetInstance().Kernel().AddressArbiter(); switch (static_cast(type)) { case AddressArbiter::ArbitrationType::WaitIfLessThan: - return AddressArbiter::WaitForAddressIfLessThan(address, value, timeout, false); + return address_arbiter.WaitForAddressIfLessThan(address, value, timeout, false); case AddressArbiter::ArbitrationType::DecrementAndWaitIfLessThan: - return AddressArbiter::WaitForAddressIfLessThan(address, value, timeout, true); + return address_arbiter.WaitForAddressIfLessThan(address, value, timeout, true); case AddressArbiter::ArbitrationType::WaitIfEqual: - return AddressArbiter::WaitForAddressIfEqual(address, value, timeout); + return address_arbiter.WaitForAddressIfEqual(address, value, timeout); default: LOG_ERROR(Kernel_SVC, "Invalid arbitration type, expected WaitIfLessThan, DecrementAndWaitIfLessThan " @@ -1509,13 +1510,14 @@ static ResultCode SignalToAddress(VAddr address, u32 type, s32 value, s32 num_to return ERR_INVALID_ADDRESS; } + auto& address_arbiter = Core::System::GetInstance().Kernel().AddressArbiter(); switch (static_cast(type)) { case AddressArbiter::SignalType::Signal: - return AddressArbiter::SignalToAddress(address, num_to_wake); + return address_arbiter.SignalToAddress(address, num_to_wake); case AddressArbiter::SignalType::IncrementAndSignalIfEqual: - return AddressArbiter::IncrementAndSignalToAddressIfEqual(address, value, num_to_wake); + return address_arbiter.IncrementAndSignalToAddressIfEqual(address, value, num_to_wake); case AddressArbiter::SignalType::ModifyByWaitingCountAndSignalIfEqual: - return AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(address, value, + return address_arbiter.ModifyByWaitingCountAndSignalToAddressIfEqual(address, value, num_to_wake); default: LOG_ERROR(Kernel_SVC, diff --git a/src/tests/core/arm/arm_test_common.cpp b/src/tests/core/arm/arm_test_common.cpp index 9b8a44fa19..ea27ef90da 100644 --- a/src/tests/core/arm/arm_test_common.cpp +++ b/src/tests/core/arm/arm_test_common.cpp @@ -13,11 +13,11 @@ namespace ArmTests { TestEnvironment::TestEnvironment(bool mutable_memory_) - : mutable_memory(mutable_memory_), test_memory(std::make_shared(this)) { - + : mutable_memory(mutable_memory_), + test_memory(std::make_shared(this)), kernel{Core::System::GetInstance()} { auto process = Kernel::Process::Create(kernel, ""); kernel.MakeCurrentProcess(process.get()); - page_table = &Core::CurrentProcess()->VMManager().page_table; + page_table = &process->VMManager().page_table; std::fill(page_table->pointers.begin(), page_table->pointers.end(), nullptr); page_table->special_regions.clear();