copy, empty and default 32bit

This commit is contained in:
Dun Liang 2020-07-22 22:40:49 +08:00
parent 29238b8cfb
commit 42a2abfe50
7 changed files with 136 additions and 8 deletions

View File

@ -198,18 +198,13 @@ def clean():
cast = unary
def array(data, dtype=None):
if type(data) == core.Var:
if isinstance(data, core.Var):
if dtype is None:
return cast(data, data.dtype)
return data.clone()
return cast(data, dtype)
if dtype != None:
return ops.array(np.array(data, dtype))
if type(data) == np.ndarray:
if data.flags.c_contiguous:
return ops.array(data)
else:
return ops.array(data.copy())
return ops.array(np.array(data))
return ops.array(data)
def grad(loss, targets):
if type(targets) == core.Var:

46
src/ops/copy_op.cc Normal file
View File

@ -0,0 +1,46 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "ops/op_register.h"
#include "ops/copy_op.h"
namespace jittor {
CopyOp::CopyOp(Var* x) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
auto y = create_output(nullptr, x->dtype());
if (x->name.ptr)
y->name = x->name;
}
VarPtr CopyOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return dout;
}
void CopyOp::infer_shape() {
outputs().front()->set_shape(inputs().front()->shape);
}
void CopyOp::run() {
auto x = inputs().front();
auto size = x->size;
auto x_ptr = x->mem_ptr;
auto y_ptr = outputs().front()->mem_ptr;
if (flags.get(NodeFlags::_cpu)) {
std::memcpy(y_ptr, x_ptr, size);
}
#ifdef HAS_CUDA
else {
std::cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0);
}
#endif
}
} // jittor

22
src/ops/copy_op.h Normal file
View File

@ -0,0 +1,22 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct CopyOp : Op {
CopyOp(Var* x);
const char* name() const override { return "copy"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
void run() override;
};
} // jittor

21
src/ops/empty_op.cc Normal file
View File

@ -0,0 +1,21 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "ops/array_op.h"
#include "ops/op_register.h"
#include "ops/empty_op.h"
namespace jittor {
EmptyOp::EmptyOp(NanoVector shape, NanoString dtype) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
create_output(shape, dtype);
}
} // jittor

19
src/ops/empty_op.h Normal file
View File

@ -0,0 +1,19 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct EmptyOp : Op {
EmptyOp(NanoVector shape, NanoString dtype=ns_float32);
const char* name() const override { return "empty"; }
};
} // jittor

View File

@ -308,6 +308,23 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
args.buffer.reset(new char[size]);
args.ptr = (void*)args.buffer.get();
memcpy((void*)args.buffer.get(), (void*)arr->data, size);
if (Py_TYPE(obj) != PyArray_Type && args.dtype.dsize()==8) {
// convert to 32bit
auto num = size/8;
if (args.dtype.is_int()) {
auto* __restrict__ i64 = (int64*)args.ptr;
auto* __restrict__ i32 = (int32*)args.ptr;
for (int i=0; i<num; i++)
i32[i] = (int32)i64[i];
args.dtype = ns_int32;
} else if (args.dtype.is_float()) {
auto* __restrict__ f64 = (float64*)args.ptr;
auto* __restrict__ f32 = (float32*)args.ptr;
for (int i=0; i<num; i++)
f32[i] = (float32)f64[i];
args.dtype = ns_float32;
}
}
return args;
}
T args;

View File

@ -178,6 +178,14 @@ struct VarHolder {
std::memcpy(var->mem_ptr, array.ptr, size);
}
// @pyjt(share_with)
// @attrs(return_self)
inline VarHolder* share_with(VarHolder* other) {
CHECK(!var->allocator) << "This var is already executed or shared.";
var->allocator = (Allocator*)(other->var);
return this;
}
// @pyjt(debug_msg)
string debug_msg();
};