Custom Data Readers(自定义数据读取器)

自定义数据读取器

先决条件:

  • 熟悉C ++。

我们将支持文件格式的任务分成两部分:

  • 文件格式:我们使用Reader Op 从文件中读取记录(可以是任何字符串)。

  • 记录格式:我们使用解码器或解析Ops将一个字符串记录转换为TensorFlow可用的张量。

例如,要读取CSV文件,我们使用Reader作为文本文件,然后使用Op来分析一行文本中的CSV数据

为文件格式编写Reader

Reader是从文件读取记录的东西。TensorFlow中已经内置了一些Reader Ops的例子:

  • tf.TFRecordReader(来源于kernels/tf_record_reader_op.cc

  • tf.FixedLengthRecordReader(来源于kernels/fixed_length_record_reader_op.cc

  • tf.TextLineReader(来源于kernels/text_line_reader_op.cc

你可以看到这些都暴露了相同的接口,唯一的区别在于它们的构造函数。最重要的方法是read。它需要一个队列参数,这是它获取文件名以从需要时读取的文件名(例如,当readop首次运行时,或前read一次从文件读取最后一个记录时)。它产生两个标量张量:一个字符串键和一个字符串值。

要创建一个新的读者SomeReader,你需要:

1. 在C ++中,定义一个tensorflow::ReaderBase被调用的子类SomeReader

2. 在C ++中,用名称注册一个新的读取器操作系统和内核"SomeReader"

3. 在Python中,定义一个tf.ReaderBase被调用的子类SomeReader

你可以把所有的C ++代码放在一个文件tensorflow/core/user_ops/some_reader_op.cc中。读取文件的代码将存放在C ++ ReaderBase类的后代中,C ++ 类定义在后者中tensorflow/core/kernels/reader_base.h。您将需要实施以下方法:

  • OnWorkStartedLocked:打开下一个文件

  • ReadLocked:读取记录或报告EOF /错误

  • OnWorkFinishedLocked:关闭当前文件,并

  • ResetLocked:在例如错误之后得到干净的平板

这些方法的名称以“Locked”结尾,因为ReaderBase在调用这些方法之前确保获得互斥体,所以您通常不必担心线程安全性(尽管只保护类的成员,而不是全局状态) 。

对于OnWorkStartedLocked要打开的文件的名称是该current_work()方法返回的值。ReadLocked有这样的签名:

Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)

如果ReadLocked成功从文件中读取记录,则应填写:

  • *key:带有记录的标识符,人可以用来再次查找该记录。你可以包含文件名current_work(),并附加一个记录号码或其他。

  • *value:与记录的内容。

  • *produced:设为true

如果您点击文件末尾(EOF),请设置*at_endtrue。无论哪种情况,都会返回Status::OK()。如果出现错误,只需使用其中一个辅助函数即可返回它,tensorflow/core/lib/core/errors.h而无需修改任何参数。

接下来,您将创建实际的Reader操作。如果您熟悉添加操作方法,这将有所帮助。主要步骤是:

  • 注册操作。

  • 定义并注册一个OpKernel

要注册该操作,您将使用在中REGISTER_OP定义的呼叫tensorflow/core/framework/op.h。读者操作系统从不接受任何输入,并且始终只有一个带有类型的输出resource。他们应该有字符串containershared_nameattrs。您可以选择定义额外的attrs进行配置或在文档中包含一个Doc。例如,请参阅tensorflow/core/ops/io_ops.cc:例如:

#include "tensorflow/core/framework/op.h" REGISTER_OP("TextLineReader") .Output("reader_handle: resource") .Attr("skip_header_lines: int = 0") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( A Reader that outputs the lines of a file delimited by '\n'. )doc"

要定义一个OpKernel,读者可以使用降序的快捷方式ReaderOpKernel,定义tensorflow/core/framework/reader_op_kernel.h和实现调用的构造函数SetReaderFactory。定义你的课程后,你需要使用注册REGISTER_KERNEL_BUILDER(...)。没有attrs的例子:

#include "tensorflow/core/framework/reader_op_kernel.h" class TFRecordReaderOp : public ReaderOpKernel { public: explicit TFRecordReaderOp(OpKernelConstruction* context) : ReaderOpKernel(context) { Env* env = context->env( SetReaderFactory([this, env]() { return new TFRecordReader(name(), env } } }; REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU), TFRecordReaderOp

有attrs的一个例子:

#include "tensorflow/core/framework/reader_op_kernel.h" class TextLineReaderOp : public ReaderOpKernel { public: explicit TextLineReaderOp(OpKernelConstruction* context) : ReaderOpKernel(context) { int skip_header_lines = -1; OP_REQUIRES_OK(context, context->GetAttr("skip_header_lines", &skip_header_lines) OP_REQUIRES(context, skip_header_lines >= 0, errors::InvalidArgument("skip_header_lines must be >= 0 not ", skip_header_lines) Env* env = context->env( SetReaderFactory([this, skip_header_lines, env]() { return new TextLineReader(name(), skip_header_lines, env } } }; REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU), TextLineReaderOp

最后一步是添加Python包装器。你可以通过编译一个动态库来实现,或者如果你是从源代码构建TensorFlow,添加到user_ops.py。对于后者,您将导入tensorflow.python.ops.io_opstensorflow/python/user_ops/user_ops.py添加的后裔io_ops.ReaderBase

from tensorflow.python.framework import ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import io_ops class SomeReader(io_ops.ReaderBase): def __init__(self, name=None): rr = gen_user_ops.some_reader(name=name) super(SomeReader, self).__init__(rr) ops.NotDifferentiable("SomeReader")

你可以看到一些例子tensorflow/python/ops/io_ops.py

为记录格式编写操作

通常这是一个普通的操作,它将标量字符串记录作为输入,因此按照说明添加操作。您可以选择使用标量字符串键作为输入,并将其包含在报告格式不正确的数据的错误消息中。这样用户可以更轻松地追踪坏数据的来源。

可用于解码记录的Ops示例:

  • tf.parse_single_example(和tf.parse_example

  • tf.decode_csv

  • tf.decode_raw

请注意,使用多个Ops来解码特定的记录格式会很有用。例如,可能必须保存为一个字符串的图像一个tf.train.Example协议缓冲器。根据该图像的格式,你可能会采取相应的输出从tf.parse_single_exampleOP和呼叫tf.image.decode_jpegtf.image.decode_pngtf.decode_raw。采用输出tf.decode_raw和使用tf.slice以及tf.reshape提取碎片是很常见的。