mirror of https://github.com/Jittor/Jittor
copy, empty and default 32bit
This commit is contained in:
parent
29238b8cfb
commit
42a2abfe50
|
@ -198,18 +198,13 @@ def clean():
|
|||
cast = unary
|
||||
|
||||
def array(data, dtype=None):
|
||||
if type(data) == core.Var:
|
||||
if isinstance(data, core.Var):
|
||||
if dtype is None:
|
||||
return cast(data, data.dtype)
|
||||
return data.clone()
|
||||
return cast(data, dtype)
|
||||
if dtype != None:
|
||||
return ops.array(np.array(data, dtype))
|
||||
if type(data) == np.ndarray:
|
||||
if data.flags.c_contiguous:
|
||||
return ops.array(data)
|
||||
else:
|
||||
return ops.array(data.copy())
|
||||
return ops.array(np.array(data))
|
||||
return ops.array(data)
|
||||
|
||||
def grad(loss, targets):
|
||||
if type(targets) == core.Var:
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/copy_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
CopyOp::CopyOp(Var* x) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
auto y = create_output(nullptr, x->dtype());
|
||||
if (x->name.ptr)
|
||||
y->name = x->name;
|
||||
}
|
||||
|
||||
VarPtr CopyOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return dout;
|
||||
}
|
||||
|
||||
void CopyOp::infer_shape() {
|
||||
outputs().front()->set_shape(inputs().front()->shape);
|
||||
}
|
||||
|
||||
void CopyOp::run() {
|
||||
auto x = inputs().front();
|
||||
auto size = x->size;
|
||||
auto x_ptr = x->mem_ptr;
|
||||
auto y_ptr = outputs().front()->mem_ptr;
|
||||
if (flags.get(NodeFlags::_cpu)) {
|
||||
std::memcpy(y_ptr, x_ptr, size);
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
else {
|
||||
std::cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,22 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CopyOp : Op {
|
||||
CopyOp(Var* x);
|
||||
|
||||
const char* name() const override { return "copy"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
void run() override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/empty_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
EmptyOp::EmptyOp(NanoVector shape, NanoString dtype) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
create_output(shape, dtype);
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,19 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct EmptyOp : Op {
|
||||
EmptyOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||
|
||||
const char* name() const override { return "empty"; }
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -308,6 +308,23 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
|||
args.buffer.reset(new char[size]);
|
||||
args.ptr = (void*)args.buffer.get();
|
||||
memcpy((void*)args.buffer.get(), (void*)arr->data, size);
|
||||
if (Py_TYPE(obj) != PyArray_Type && args.dtype.dsize()==8) {
|
||||
// convert to 32bit
|
||||
auto num = size/8;
|
||||
if (args.dtype.is_int()) {
|
||||
auto* __restrict__ i64 = (int64*)args.ptr;
|
||||
auto* __restrict__ i32 = (int32*)args.ptr;
|
||||
for (int i=0; i<num; i++)
|
||||
i32[i] = (int32)i64[i];
|
||||
args.dtype = ns_int32;
|
||||
} else if (args.dtype.is_float()) {
|
||||
auto* __restrict__ f64 = (float64*)args.ptr;
|
||||
auto* __restrict__ f32 = (float32*)args.ptr;
|
||||
for (int i=0; i<num; i++)
|
||||
f32[i] = (float32)f64[i];
|
||||
args.dtype = ns_float32;
|
||||
}
|
||||
}
|
||||
return args;
|
||||
}
|
||||
T args;
|
||||
|
|
|
@ -178,6 +178,14 @@ struct VarHolder {
|
|||
std::memcpy(var->mem_ptr, array.ptr, size);
|
||||
}
|
||||
|
||||
// @pyjt(share_with)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* share_with(VarHolder* other) {
|
||||
CHECK(!var->allocator) << "This var is already executed or shared.";
|
||||
var->allocator = (Allocator*)(other->var);
|
||||
return this;
|
||||
}
|
||||
|
||||
// @pyjt(debug_msg)
|
||||
string debug_msg();
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue