Skip to content

Commit

Permalink
add is_cpu api (#10172)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjchen2 authored Apr 20, 2023
1 parent 2d54365 commit b9d012c
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 6 deletions.
5 changes: 5 additions & 0 deletions oneflow/api/python/framework/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,10 @@ static PyObject* PyTensorObject_dtype(PyObject* self, void* unused) {
END_HANDLE_ERRORS
}

// Getter for the `is_cpu` Python property: true when the tensor is placed on CPU.
static PyObject* PyTensorObject_is_cpu(PyObject* self, void* unused) {
  const auto& tensor = PyTensor_Unpack(self);
  return functional::CastToPyObject(tensor->is_cpu());
}

// Getter for the `is_cuda` Python property: true when the tensor is placed on a CUDA device.
static PyObject* PyTensorObject_is_cuda(PyObject* self, void* unused) {
  const auto& tensor = PyTensor_Unpack(self);
  return functional::CastToPyObject(tensor->is_cuda());
}
Expand Down Expand Up @@ -701,6 +705,7 @@ static PyGetSetDef PyTensorObject_properties[] = {
{PYGETSET_NAME("ndim"), (getter)PyTensorObject_ndim, NULL, NULL, NULL},
{PYGETSET_NAME("shape"), (getter)PyTensorObject_shape, NULL, NULL, NULL},
{PYGETSET_NAME("dtype"), (getter)PyTensorObject_dtype, NULL, NULL, NULL},
{PYGETSET_NAME("is_cpu"), (getter)PyTensorObject_is_cpu, NULL, NULL, NULL},
{PYGETSET_NAME("is_cuda"), (getter)PyTensorObject_is_cuda, NULL, NULL, NULL},
{PYGETSET_NAME("grad"), (getter)PyTensorObject_grad, (setter)PyTensorObject_set_grad, NULL,
NULL},
Expand Down
4 changes: 2 additions & 2 deletions oneflow/api/python/framework/tensor_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ PyObject* concat_self(PyObject* self, PyObject* args) {
PyObject* ndarray_judgment_and_compatibility(PyObject* self, PyObject* other) {
if (PyArray_Check(other)) {
const auto& tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(!tensor->is_cuda())
<< Error::RuntimeError() << "Can't convert cuda device type tensor to numpy";
CHECK_OR_THROW(tensor->is_cpu())
<< Error::RuntimeError() << "Can't convert non-cpu device tensor to numpy";
if (tensor->is_global()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
Expand Down
8 changes: 6 additions & 2 deletions oneflow/core/framework/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ std::shared_ptr<Tensor> Parameter::pin_memory() const {
}
}

// A local tensor is on CPU iff its device's type string is "cpu".
bool LocalTensor::is_cpu() const {
  const auto& dev = CHECK_JUST(device());
  return dev->type() == "cpu";
}
// A local tensor is on CUDA iff its device's type string is "cuda".
bool LocalTensor::is_cuda() const {
  const auto& dev = CHECK_JUST(device());
  return dev->type() == "cuda";
}

Maybe<Tensor> LocalTensor::detach() const {
Expand Down Expand Up @@ -144,8 +145,8 @@ Maybe<void> LocalTensor::set_data(const std::shared_ptr<Tensor>& other) {
}

#define TENSOR_OFFLOAD_CHECK(is_offloaded, msg) \
if (!is_cuda()) { \
LOG(WARNING) << "Only cuda tensor can be offloaded."; \
if (is_cpu()) { \
LOG(WARNING) << "Only non-cpu tensor can be offloaded."; \
return Maybe<void>::Ok(); \
} \
if (is_offloaded_ != is_offloaded) { \
Expand Down Expand Up @@ -233,6 +234,9 @@ Maybe<GlobalTensor> GlobalTensor::MakeTensor(const std::shared_ptr<const Shape>&
return std::make_shared<GlobalTensor>(impl);
}

// A global tensor is on CPU iff its placement's device type is kCPU.
bool GlobalTensor::is_cpu() const {
  const auto& placement = CHECK_JUST(parallel_desc());
  return placement->device_type() == DeviceType::kCPU;
}
// A global tensor is on CUDA iff its placement's device type is kCUDA.
bool GlobalTensor::is_cuda() const {
  const auto& placement = CHECK_JUST(parallel_desc());
  return placement->device_type() == DeviceType::kCUDA;
}
Expand Down
8 changes: 8 additions & 0 deletions oneflow/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
virtual Maybe<Symbol<ParallelDesc>> parallel_desc() const = 0;
virtual Maybe<Symbol<Device>> device() const = 0;
virtual Maybe<Symbol<Device>*> mut_device() = 0;
virtual bool is_cpu() const = 0;
virtual bool is_cuda() const = 0;
virtual bool is_global() const = 0;
virtual bool is_local() const { return !is_global(); }
Expand Down Expand Up @@ -161,6 +162,10 @@ class StaticZerosTensor final : public Tensor {
Maybe<Symbol<ParallelDesc>> parallel_desc() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<Device>> device() const override { return device_; }
Maybe<Symbol<Device>*> mut_device() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
// Device placement is not meaningful for StaticZerosTensor; reaching this
// query is treated as a bug (aborts via PRINT_BUG_PROMPT_AND_ABORT).
// The `return false` is unreachable and only satisfies the signature.
bool is_cpu() const override {
  PRINT_BUG_PROMPT_AND_ABORT();
  return false;
}
bool is_cuda() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
Expand Down Expand Up @@ -344,6 +349,7 @@ class ProxyTensor : public TensorIf<DerivedT> {
}
virtual Maybe<Symbol<Device>> device() const override { return tensor_->device(); }
virtual Maybe<Symbol<Device>*> mut_device() override { return tensor_->mut_device(); }
virtual bool is_cpu() const override { return tensor_->is_cpu(); }
virtual bool is_cuda() const override { return tensor_->is_cuda(); }
virtual bool is_global() const override { return tensor_->is_global(); }
virtual bool is_local() const override { return tensor_->is_local(); }
Expand Down Expand Up @@ -515,6 +521,7 @@ class LocalTensor final : public TensorIf<LocalTensor> {
Maybe<Symbol<Device>*> mut_device() override { return impl_->mut_device(); }
bool is_lazy() const override { return impl_->is_lazy(); }
bool is_global() const override { return false; }
bool is_cpu() const override;
bool is_cuda() const override;
std::shared_ptr<Tensor> contiguous() const override;

Expand Down Expand Up @@ -644,6 +651,7 @@ class GlobalTensor final : public TensorIf<GlobalTensor> {
return impl_->consumer_nd_sbp_constraint();
}
Maybe<LocalTensor> cur_rank_phy_tensor() const override { return impl_->cur_rank_phy_tensor(); }
bool is_cpu() const override;
bool is_cuda() const override;
std::shared_ptr<Tensor> contiguous() const override;
Maybe<Tensor> data() override { return this->detach(); }
Expand Down
3 changes: 1 addition & 2 deletions python/oneflow/nn/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
f"Given normalized_shape={normalized_shape}, expected input with shape [*, {str(normalized_shape)[1:-1]}], but got input of size {input.shape}"
)

input_device_type = input.device.type if input.is_local else input.placement.type
if input_device_type == "cpu":
if input.is_cpu:
reduce_axis = []
for dim in range(len(input.shape)):
if dim >= begin_norm_axis:
Expand Down

0 comments on commit b9d012c

Please sign in to comment.