Skip to content
58 changes: 58 additions & 0 deletions runtime/core/device_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/core/device_allocator.h>

#include <executorch/runtime/platform/assert.h>

namespace executorch {
namespace runtime {

DeviceAllocatorRegistry& DeviceAllocatorRegistry::instance() {
static DeviceAllocatorRegistry registry;
return registry;
}

void DeviceAllocatorRegistry::register_allocator(
etensor::DeviceType type,
DeviceAllocator* alloc) {
auto index = static_cast<size_t>(type);
ET_CHECK_MSG(
index < etensor::kNumDeviceTypes,
"Invalid device type: %d",
static_cast<int>(type));
ET_CHECK_MSG(
allocators_[index] == nullptr,
"Allocator already registered for device type: %d",
static_cast<int>(type));
allocators_[index] = alloc;
}

DeviceAllocator* DeviceAllocatorRegistry::get_allocator(
etensor::DeviceType type) {
auto index = static_cast<size_t>(type);
if (index >= etensor::kNumDeviceTypes) {
return nullptr;
}
return allocators_[index];
}

// Convenience free functions

void register_device_allocator(
etensor::DeviceType type,
DeviceAllocator* alloc) {
DeviceAllocatorRegistry::instance().register_allocator(type, alloc);
}

DeviceAllocator* get_device_allocator(etensor::DeviceType type) {
return DeviceAllocatorRegistry::instance().get_allocator(type);
}

} // namespace runtime
} // namespace executorch
156 changes: 156 additions & 0 deletions runtime/core/device_allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstddef>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/portable_type/device.h>
#include <executorch/runtime/core/result.h>

namespace executorch {
namespace runtime {

/**
* Abstract interface for device-specific memory allocation.
*
* Each device type (CUDA, etc.) provides a concrete implementation
* that handles memory allocation on that device. Implementations are
* expected to be singletons with static lifetime, registered via
* DeviceAllocatorRegistry.

*/
class DeviceAllocator {
public:
virtual ~DeviceAllocator() = default;
/**
* Allocate device memory.
*
* @param nbytes Number of bytes to allocate.
* @param index The device index.
* @return A Result containing the device pointer on success, or an error.
*/
virtual Result<void*> allocate(size_t nbytes, etensor::DeviceIndex index) = 0;

/**
* Deallocate device memory previously allocated via allocate().
*
* @param ptr Pointer to the memory to deallocate.
* @param index The device index.
*/
virtual void deallocate(void* ptr, etensor::DeviceIndex index) = 0;

/**
* Copy data from host memory to device memory.
*
* @param dst Destination pointer (device memory).
* @param src Source pointer (host memory).
* @param nbytes Number of bytes to copy.
* @param index The device index.
* @return Error::Ok on success, or an appropriate error code on failure.
*/
virtual Error copy_host_to_device(
void* dst,
const void* src,
size_t nbytes,
etensor::DeviceIndex index) = 0;

/**
* Copy data from device memory to host memory.
*
* @param dst Destination pointer (host memory).
* @param src Source pointer (device memory).
* @param nbytes Number of bytes to copy.
* @param index The device index.
* @return Error::Ok on success, or an appropriate error code on failure.
*/
virtual Error copy_device_to_host(
void* dst,
const void* src,
size_t nbytes,
etensor::DeviceIndex index) = 0;

/**
* Returns the device type this allocator handles.
*/
virtual etensor::DeviceType device_type() const = 0;
};

/**
* Registry for device allocators.
*
* Provides a global mapping from DeviceType to DeviceAllocator instances.
* Device allocators register themselves at static initialization time,
* and the runtime queries the registry to find the appropriate allocator
* for a given device type.
*/
class DeviceAllocatorRegistry {
public:
/**
* Returns the singleton instance of the registry.
*/
static DeviceAllocatorRegistry& instance();

/**
* Register an allocator for a specific device type.
*
* @param type The device type this allocator handles.
* @param alloc Pointer to the allocator (must have static lifetime).
*/
void register_allocator(etensor::DeviceType type, DeviceAllocator* alloc);

/**
* Get the allocator for a specific device type.
*
* @param type The device type.
* @return Pointer to the allocator, or nullptr if not registered.
*/
DeviceAllocator* get_allocator(etensor::DeviceType type);

private:
DeviceAllocatorRegistry() = default;

// Fixed-size array indexed by device type. This avoids dynamic allocation
// and is suitable for embedded environments.
DeviceAllocator* allocators_[etensor::kNumDeviceTypes] = {};
};

// Convenience free functions

/**
* Register a device allocator for a specific device type.
*
* @param type The device type this allocator handles.
* @param alloc Pointer to the allocator (must have static lifetime).
*/
void register_device_allocator(
etensor::DeviceType type,
DeviceAllocator* alloc);

/**
* Get the device allocator for a specific device type.
*
* @param type The device type.
* @return Pointer to the allocator, or nullptr if not registered.
*/
DeviceAllocator* get_device_allocator(etensor::DeviceType type);

} // namespace runtime
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::DeviceAllocator;
using ::executorch::runtime::DeviceAllocatorRegistry;
using ::executorch::runtime::get_device_allocator;
using ::executorch::runtime::register_device_allocator;
} // namespace executor
} // namespace torch
Loading
Loading