Signed-off-by: jagger <cossjie@foxmail.com>

Former-commit-id: 2d3e239f42
This commit is contained in:
jagger 2024-07-03 16:10:50 +08:00
parent c60c2a819f
commit 6aa89132a8
6695 changed files with 15 additions and 2066492 deletions

View File

@ -54,7 +54,7 @@ workflow:
image_name: '"registry.cn-hangzhou.aliyuncs.com/jcce/pcm-core-api"'
image_tag: '"latest"'
registry_address: '"registry.cn-hangzhou.aliyuncs.com"'
docker_file: git_clone_0.git_path + '/api/Dockerfile'
docker_file: git_clone_0.git_path + '/Dockerfile'
docker_build_path: git_clone_0.git_path
workspace: git_clone_0.git_path
image_clean: true

View File

@ -1,34 +1,28 @@
FROM registry.cn-hangzhou.aliyuncs.com/jcce-images/golang:1.22.4-alpine3.20 AS builder
LABEL stage=gobuilder
ENV CGO_ENABLED 0
ENV GOARCH amd64
ENV GOPROXY https://goproxy.cn,direct
WORKDIR /app
ADD go.mod .
ADD go.sum .
RUN go mod download
COPY . .
COPY api/etc/ /app/
RUN go build -o pcm-coordinator-api /app/api/pcm.go
ENV GO111MODULE=on GOPROXY=https://goproxy.cn,direct CGO_ENABLED=0
RUN go mod download
RUN go build -ldflags="-w -s" -o pcm-core-api
FROM registry.cn-hangzhou.aliyuncs.com/jcce-images/alpine3.20
FROM registry.cn-hangzhou.aliyuncs.com/jcce-images/alpine:3.20
WORKDIR /app
#修改alpine源为上海交通大学
RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.sjtug.sjtu.edu.cn/g' /etc/apk/repositories && \
apk add --no-cache ca-certificates tzdata && \
update-ca-certificates && \
rm -rf /var/cache/apk/*
apk add --no-cache ca-certificates && update-ca-certificates && \
apk add --update tzdata && \
rm -rf /var/cache/apk/*
COPY --from=builder /app/pcm-coordinator-api /app/
COPY --from=builder /app/api/etc/pcm.yaml /app/
COPY --from=builder /app/pcm-core-api .
COPY etc/pcm.yaml .
ENV TZ=Asia/Shanghai
EXPOSE 8999
EXPOSE 2002
ENTRYPOINT ["./pcm-coordinator-api", "-f", "pcm.yaml"]
ENTRYPOINT ./pcm-core-api --f pcm.yaml

3
go.mod
View File

@ -2,9 +2,6 @@ module gitlink.org.cn/JointCloud/pcm-coordinator
go 1.22.0
toolchain go1.22.4
retract v0.1.20-0.20240319015239-6ae13da05255
require (
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818

View File

@ -1,27 +0,0 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,14 +0,0 @@
# filippo.io/edwards25519
```
import "filippo.io/edwards25519"
```
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.

View File

@ -1,20 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/internal/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519

View File

@ -1,427 +0,0 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"errors"
"filippo.io/edwards25519/field"
)
// Point types.
type projP1xP1 struct {
X, Y, Z, T field.Element
}
type projP2 struct {
X, Y, Z field.Element
}
// Point represents a point on the edwards25519 curve.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is NOT valid, and it may be used only as a receiver.
type Point struct {
// Make the type not comparable (i.e. used with == or as a map key), as
// equivalent points can be represented by different Go values.
_ incomparable
// The point is internally represented in extended coordinates (X, Y, Z, T)
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
x, y, z, t field.Element
}
type incomparable [0]func()
func checkInitialized(points ...*Point) {
for _, p := range points {
if p.x == (field.Element{}) && p.y == (field.Element{}) {
panic("edwards25519: use of uninitialized Point")
}
}
}
type projCached struct {
YplusX, YminusX, Z, T2d field.Element
}
type affineCached struct {
YplusX, YminusX, T2d field.Element
}
// Constructors.
func (v *projP2) Zero() *projP2 {
v.X.Zero()
v.Y.One()
v.Z.One()
return v
}
// identity is the point at infinity.
var identity, _ = new(Point).SetBytes([]byte{
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
var generator, _ = new(Point).SetBytes([]byte{
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
return new(Point).Set(generator)
}
func (v *projCached) Zero() *projCached {
v.YplusX.One()
v.YminusX.One()
v.Z.One()
v.T2d.Zero()
return v
}
func (v *affineCached) Zero() *affineCached {
v.YplusX.One()
v.YminusX.One()
v.T2d.Zero()
return v
}
// Assignments.
// Set sets v = u, and returns v.
func (v *Point) Set(u *Point) *Point {
*v = *u
return v
}
// Encoding.
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytes(&buf)
}
func (v *Point) bytes(buf *[32]byte) []byte {
checkInitialized(v)
var zInv, x, y field.Element
zInv.Invert(&v.z) // zInv = 1 / Z
x.Multiply(&v.x, &zInv) // x = X / Z
y.Multiply(&v.y, &zInv) // y = Y / Z
out := copyFieldElement(buf, &y)
out[31] |= byte(x.IsNegative() << 7)
return out
}
var feOne = new(field.Element).One()
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
// Note that SetBytes accepts all non-canonical encodings of valid points.
// That is, it follows decoding rules that match most implementations in
// the ecosystem rather than RFC 8032.
func (v *Point) SetBytes(x []byte) (*Point, error) {
// Specifically, the non-canonical encodings that are accepted are
// 1) the ones where the field element is not reduced (see the
// (*field.Element).SetBytes docs) and
// 2) the ones where the x-coordinate is zero and the sign bit is set.
//
// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
// specifically the "Canonical A, R" section.
y, err := new(field.Element).SetBytes(x)
if err != nil {
return nil, errors.New("edwards25519: invalid point encoding length")
}
// -x² + y² = 1 + dx²y²
// x² + dx²y² = x²(dy² + 1) = y² - 1
// x² = (y² - 1) / (dy² + 1)
// u = y² - 1
y2 := new(field.Element).Square(y)
u := new(field.Element).Subtract(y2, feOne)
// v = dy² + 1
vv := new(field.Element).Multiply(y2, d)
vv = vv.Add(vv, feOne)
// x = +√(u/v)
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
if wasSquare == 0 {
return nil, errors.New("edwards25519: invalid point encoding")
}
// Select the negative square root if the sign bit is set.
xxNeg := new(field.Element).Negate(xx)
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
v.x.Set(xx)
v.y.Set(y)
v.z.One()
v.t.Multiply(xx, y) // xy = T / Z
return v, nil
}
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
copy(buf[:], v.Bytes())
return buf[:]
}
// Conversions.
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
v.X.Multiply(&p.X, &p.T)
v.Y.Multiply(&p.Y, &p.Z)
v.Z.Multiply(&p.Z, &p.T)
return v
}
func (v *projP2) FromP3(p *Point) *projP2 {
v.X.Set(&p.x)
v.Y.Set(&p.y)
v.Z.Set(&p.z)
return v
}
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
v.x.Multiply(&p.X, &p.T)
v.y.Multiply(&p.Y, &p.Z)
v.z.Multiply(&p.Z, &p.T)
v.t.Multiply(&p.X, &p.Y)
return v
}
func (v *Point) fromP2(p *projP2) *Point {
v.x.Multiply(&p.X, &p.Z)
v.y.Multiply(&p.Y, &p.Z)
v.z.Square(&p.Z)
v.t.Multiply(&p.X, &p.Y)
return v
}
// d is a constant in the curve equation.
var d, _ = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
var d2 = new(field.Element).Add(d, d)
func (v *projCached) FromP3(p *Point) *projCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.Z.Set(&p.z)
v.T2d.Multiply(&p.t, d2)
return v
}
func (v *affineCached) FromP3(p *Point) *affineCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.T2d.Multiply(&p.t, d2)
var invZ field.Element
invZ.Invert(&p.z)
v.YplusX.Multiply(&v.YplusX, &invZ)
v.YminusX.Multiply(&v.YminusX, &invZ)
v.T2d.Multiply(&v.T2d, &invZ)
return v
}
// (Re)addition and subtraction.
// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Sub(p, qCached)
return v.fromP1xP1(result)
}
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&ZZ2, &TT2d)
v.T.Subtract(&ZZ2, &TT2d)
return v
}
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
v.T.Add(&ZZ2, &TT2d) // flipped sign
return v
}
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&Z2, &TT2d)
v.T.Subtract(&Z2, &TT2d)
return v
}
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&Z2, &TT2d) // flipped sign
v.T.Add(&Z2, &TT2d) // flipped sign
return v
}
// Doubling.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
var XX, YY, ZZ2, XplusYsq field.Element
XX.Square(&p.X)
YY.Square(&p.Y)
ZZ2.Square(&p.Z)
ZZ2.Add(&ZZ2, &ZZ2)
XplusYsq.Add(&p.X, &p.Y)
XplusYsq.Square(&XplusYsq)
v.Y.Add(&YY, &XX)
v.Z.Subtract(&YY, &XX)
v.X.Subtract(&XplusYsq, &v.Y)
v.T.Subtract(&ZZ2, &v.Z)
return v
}
// Negation.
// Negate sets v = -p, and returns v.
func (v *Point) Negate(p *Point) *Point {
checkInitialized(p)
v.x.Negate(&p.x)
v.y.Set(&p.y)
v.z.Set(&p.z)
v.t.Negate(&p.t)
return v
}
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
func (v *Point) Equal(u *Point) int {
checkInitialized(v, u)
var t1, t2, t3, t4 field.Element
t1.Multiply(&v.x, &u.z)
t2.Multiply(&u.x, &v.z)
t3.Multiply(&v.y, &u.z)
t4.Multiply(&u.y, &v.z)
return t1.Equal(&t2) & t3.Equal(&t4)
}
// Constant-time operations
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.Z.Select(&a.Z, &b.Z, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *projCached) CondNeg(cond int) *projCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *affineCached) CondNeg(cond int) *affineCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}

View File

@ -1,349 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// This file contains additional functionality that is not included in the
// upstream crypto/internal/edwards25519 package.
import (
"errors"
"filippo.io/edwards25519/field"
)
// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap. Don't change the style without making
// sure it doesn't increase the inliner cost.
var e [4]field.Element
X, Y, Z, T = v.extendedCoordinates(&e)
return
}
func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
checkInitialized(v)
X = e[0].Set(&v.x)
Y = e[1].Set(&v.y)
Z = e[2].Set(&v.z)
T = e[3].Set(&v.t)
return
}
// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
//
// If the coordinates are invalid or don't represent a valid point on the curve,
// SetExtendedCoordinates returns nil and an error and the receiver is
// unchanged. Otherwise, SetExtendedCoordinates returns v.
func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
if !isOnCurve(X, Y, Z, T) {
return nil, errors.New("edwards25519: invalid point coordinates")
}
v.x.Set(X)
v.y.Set(Y)
v.z.Set(Z)
v.t.Set(T)
return v, nil
}
func isOnCurve(X, Y, Z, T *field.Element) bool {
var lhs, rhs field.Element
XX := new(field.Element).Square(X)
YY := new(field.Element).Square(Y)
ZZ := new(field.Element).Square(Z)
TT := new(field.Element).Square(T)
// -x² + y² = 1 + dx²y²
// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
// -X² + Y² = Z² + dT²
lhs.Subtract(YY, XX)
rhs.Multiply(d, TT).Add(&rhs, ZZ)
if lhs.Equal(&rhs) != 1 {
return false
}
// xy = T/Z
// XY/Z² = T/Z
// XY = TZ
lhs.Multiply(X, Y)
rhs.Multiply(T, Z)
return lhs.Equal(&rhs) == 1
}
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
//
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
// while every valid edwards25519 point has a unique u-coordinate Montgomery
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
// to any edwards25519 point, and every other X25519 input corresponds to two
// edwards25519 points.
func (v *Point) BytesMontgomery() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytesMontgomery(&buf)
}
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
checkInitialized(v)
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
// Montgomery u-coordinate
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
var y, recip, u field.Element
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
return copyFieldElement(buf, &u)
}
// MultByCofactor sets v = 8 * p, and returns v.
func (v *Point) MultByCofactor(p *Point) *Point {
checkInitialized(p)
result := projP1xP1{}
pp := (&projP2{}).FromP3(p)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
return v.fromP1xP1(&result)
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert returns zero.
func (s *Scalar) Invert(t *Scalar) *Scalar {
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Multiply(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Multiply(&table[i], &tt)
}
// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[5/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(5 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(9 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
return s
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called MultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].nonAdjacentForm(5)
}
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficent, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

View File

@ -1,420 +0,0 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"errors"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
// Between operations, all limbs are expected to be lower than 2^52.
l0 uint64
l1 uint64
l2 uint64
l3 uint64
l4 uint64
}
const maskLow51Bits uint64 = (1 << 51) - 1
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation.
return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
return v.carryPropagate()
}
// Negate sets v = -a, and returns v.
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set sets v = a, and returns v.
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
// not of the right length, SetBytes returns nil and an error, and the
// receiver is unchanged.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032, but
// consistent with most Ed25519 implementations.
func (v *Element) SetBytes(x []byte) (*Element, error) {
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid field element input size")
}
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
v.l0 = binary.LittleEndian.Uint64(x[0:8])
v.l0 &= maskLow51Bits
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
v.l1 &= maskLow51Bits
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
v.l2 &= maskLow51Bits
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
v.l3 &= maskLow51Bits
// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
// Note: not bytes 25:33, shift 4, to avoid overread.
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
v.l4 &= maskLow51Bits
return v, nil
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
v.l3 = (m & a.l3) | (^m & b.l3)
v.l4 = (m & a.l4) | (^m & b.l4)
return v
}
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
t = m & (v.l1 ^ u.l1)
v.l1 ^= t
u.l1 ^= t
t = m & (v.l2 ^ u.l2)
v.l2 ^= t
u.l2 ^= t
t = m & (v.l3 ^ u.l3)
v.l3 ^= t
u.l3 ^= t
t = m & (v.l4 ^ u.l4)
v.l4 ^= t
u.l4 ^= t
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
x3lo, x3hi := mul51(x.l3, y)
x4lo, x4hi := mul51(x.l4, y)
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
v.l1 = x1lo + x0hi
v.l2 = x2lo + x1hi
v.l3 = x3lo + x2hi
v.l4 = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
t0 := new(Element)
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := new(Element).Square(v)
uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
uv7 := new(Element).Multiply(uv3, t0.Square(v2))
rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))
check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2
uNeg := new(Element).Negate(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))
rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(rr) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}

View File

@ -1,16 +0,0 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
package field
// feMul sets out = a * b. It works like feMulGeneric.
//
//go:noescape
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//
//go:noescape
func feSquare(out *Element, a *Element)

View File

@ -1,379 +0,0 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
// r0 = a0×b0
MOVQ (CX), AX
MULQ (BX)
MOVQ AX, DI
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
MULQ 8(BX)
MOVQ AX, R9
MOVQ DX, R8
// r1 += a1×b0
MOVQ 8(CX), AX
MULQ (BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
MULQ 16(BX)
MOVQ AX, R11
MOVQ DX, R10
// r2 += a1×b1
MOVQ 8(CX), AX
MULQ 8(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += a2×b0
MOVQ 16(CX), AX
MULQ (BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
MULQ 24(BX)
MOVQ AX, R13
MOVQ DX, R12
// r3 += a1×b2
MOVQ 8(CX), AX
MULQ 16(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a2×b1
MOVQ 16(CX), AX
MULQ 8(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a3×b0
MOVQ 24(CX), AX
MULQ (BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
MULQ 32(BX)
MOVQ AX, R15
MOVQ DX, R14
// r4 += a1×b3
MOVQ 8(CX), AX
MULQ 24(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a2×b2
MOVQ 16(CX), AX
MULQ 16(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a3×b1
MOVQ 24(CX), AX
MULQ 8(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a4×b0
MOVQ 32(CX), AX
MULQ (BX)
ADDQ AX, R15
ADCQ DX, R14
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, DI, SI
SHLQ $0x0d, R9, R8
SHLQ $0x0d, R11, R10
SHLQ $0x0d, R13, R12
SHLQ $0x0d, R15, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Second reduction chain (carryPropagate)
MOVQ DI, SI
SHRQ $0x33, SI
MOVQ R9, R8
SHRQ $0x33, R8
MOVQ R11, R10
SHRQ $0x33, R10
MOVQ R13, R12
SHRQ $0x33, R12
MOVQ R15, R14
SHRQ $0x33, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Store output
MOVQ out+0(FP), AX
MOVQ DI, (AX)
MOVQ R9, 8(AX)
MOVQ R11, 16(AX)
MOVQ R13, 24(AX)
MOVQ R15, 32(AX)
RET
// func feSquare(out *Element, a *Element)
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX
// r0 = l0×l0
MOVQ (CX), AX
MULQ (CX)
MOVQ AX, SI
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 8(CX)
MOVQ AX, R8
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
MOVQ AX, R10
MOVQ DX, R9
// r2 += l1×l1
MOVQ 8(CX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
MOVQ AX, R12
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 32(CX)
MOVQ AX, R14
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R13
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, SI, BX
SHLQ $0x0d, R8, DI
SHLQ $0x0d, R10, R9
SHLQ $0x0d, R12, R11
SHLQ $0x0d, R14, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Second reduction chain (carryPropagate)
MOVQ SI, BX
SHRQ $0x33, BX
MOVQ R8, DI
SHRQ $0x33, DI
MOVQ R10, R9
SHRQ $0x33, R9
MOVQ R12, R11
SHRQ $0x33, R11
MOVQ R14, R13
SHRQ $0x33, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Store output
MOVQ out+0(FP), AX
MOVQ SI, (AX)
MOVQ R8, 8(AX)
MOVQ R10, 16(AX)
MOVQ R12, 24(AX)
MOVQ R14, 32(AX)
RET

View File

@ -1,12 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package field
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

View File

@ -1,16 +0,0 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field
//go:noescape
func carryPropagate(v *Element)
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}

View File

@ -1,42 +0,0 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET

View File

@ -1,12 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}

View File

@ -1,50 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "errors"
// This file contains additional functionality that is not included in the
// upstream crypto/ed25519/edwards25519/field package.
// SetWideBytes sets v to x, where x is a 64-byte little-endian encoding, which
// is reduced modulo the field order. If x is not of the right length,
// SetWideBytes returns nil and an error, and the receiver is unchanged.
//
// SetWideBytes is not necessary to select a uniformly distributed value, and is
// only provided for compatibility: SetBytes can be used instead as the chance
// of bias is less than 2⁻²⁵⁰.
func (v *Element) SetWideBytes(x []byte) (*Element, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetWideBytes input size")
}
// Split the 64 bytes into two elements, and extract the most significant
// bit of each, which is ignored by SetBytes.
lo, _ := new(Element).SetBytes(x[:32])
loMSB := uint64(x[31] >> 7)
hi, _ := new(Element).SetBytes(x[32:])
hiMSB := uint64(x[63] >> 7)
// The output we want is
//
// v = lo + loMSB * 2²⁵⁵ + hi * 2²⁵⁶ + hiMSB * 2⁵¹¹
//
// which applying the reduction identity comes out to
//
// v = lo + loMSB * 19 + hi * 2 * 19 + hiMSB * 2 * 19²
//
// l0 will be the sum of a 52 bits value (lo.l0), plus a 5 bits value
// (loMSB * 19), a 6 bits value (hi.l0 * 2 * 19), and a 10 bits value
// (hiMSB * 2 * 19²), so it fits in a uint64.
v.l0 = lo.l0 + loMSB*19 + hi.l0*2*19 + hiMSB*2*19*19
v.l1 = lo.l1 + hi.l1*2*19
v.l2 = lo.l2 + hi.l2*2*19
v.l3 = lo.l3 + hi.l3*2*19
v.l4 = lo.l4 + hi.l4*2*19
return v.carryPropagate(), nil
}

View File

@ -1,266 +0,0 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}

View File

@ -1,343 +0,0 @@
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
// s is the scalar in the Montgomery domain, in the format of the
// fiat-crypto implementation.
s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
return &Scalar{}
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Make a copy of z in case it aliases s.
zCopy := new(Scalar).Set(z)
return s.Multiply(x, y).Add(s, zCopy)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
fiatScalarAdd(&s.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
fiatScalarSub(&s.s, &x.s, &y.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
fiatScalarOpp(&s.s, &x.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
fiatScalarMul(&s.s, &x.s, &y.s)
return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
*s = *x
return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
}
// We have a value x of 512 bits, but our fiatScalarFromBytes function
// expects an input lower than l, which is a little over 252 bits.
//
// Instead of writing a reduction function that operates on wider inputs, we
// can interpret x as the sum of three shorter values a, b, and c.
//
// x = a + b * 2^168 + c * 2^336 mod l
//
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
// with two multiplications and two additions.
s.setShortBytes(x[:21])
t := new(Scalar).setShortBytes(x[21:42])
s.Add(s, t.Multiply(t, scalarTwo168))
t.setShortBytes(x[42:])
s.Add(s, t.Multiply(t, scalarTwo336))
return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
if len(x) >= 32 {
panic("edwards25519: internal error: setShortBytes called with a long string")
}
var buf [32]byte
copy(buf[:], x)
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
}
if !isReduced(x) {
return nil, errors.New("invalid scalar encoding")
}
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
if len(s) != 32 {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
// The description above omits the purpose of the high bits of the clamping
// for brevity, but those are also lost to reductions, and are also
// irrelevant to edwards25519 as they protect against a specific
// implementation bug that was once observed in a generic Montgomery ladder.
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
}
// We need to use the wide reduction from SetUniformBytes, since clamping
// sets the 2^254 bit, making the value higher than the order.
var wideBytes [64]byte
copy(wideBytes[:], x[:])
wideBytes[0] &= 248
wideBytes[31] &= 63
wideBytes[31] |= 64
return s.SetUniformBytes(wideBytes[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var encoded [32]byte
return s.bytes(&encoded)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
var ss fiatScalarNonMontgomeryDomainFieldElement
fiatScalarFromMontgomery(&ss, &s.s)
fiatScalarToBytes(out, (*[4]uint64)(&ss))
return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
var diff fiatScalarMontgomeryDomainFieldElement
fiatScalarSub(&diff, &s.s, &t.s)
var nonzero uint64
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
panic("w must be at least 2 by the definition of NAF")
} else if w > 8 {
panic("NAF digits must fit in int8")
}
var naf [256]int8
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
}
width := uint64(1 << w)
windowMask := uint64(width - 1)
pos := uint(0)
carry := uint64(0)
for pos < 256 {
indexU64 := pos / 64
indexBit := pos % 64
var bitBuf uint64
if indexBit < 64-w {
// This window's bits are contained in a single u64
bitBuf = digits[indexU64] >> indexBit
} else {
// Combine the current 64 bits with bits from the next 64
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
}
// Add carry into the current window
window := carry + (bitBuf & windowMask)
if window&1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0,
// then the next carry should be 0
// If carry == 1 and window & 1 == 0,
// then bit_buf & 1 == 1 so the next carry should be 1
pos += 1
continue
}
if window < width/2 {
carry = 0
naf[pos] = int8(window)
} else {
carry = 1
naf[pos] = int8(window) - int8(width)
}
pos += w
}
return naf
}
func (s *Scalar) signedRadix16() [64]int8 {
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
var digits [64]int8
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(b[i] & 15)
digits[2*i+1] = int8((b[i] >> 4) & 15)
}
// Recenter coefficients:
for i := 0; i < 63; i++ {
carry := (digits[i] + 8) >> 4
digits[i] -= carry << 4
digits[i+1] += carry
}
return digits
}

File diff suppressed because it is too large Load Diff

View File

@ -1,214 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.signedRadix16()
multiple := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Accumulate the odd components first
v.Set(NewIdentityPoint())
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.fromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
return v
}
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
checkInitialized(q)
var table projLookupTable
table.FromP3(q)
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.signedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
table.SelectInto(multiple, digits[63])
v.Set(NewIdentityPoint())
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.fromP1xP1(tmp1)
return v
}
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
checkInitialized(A)
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.nonAdjacentForm(5)
bNaf := b.nonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
break
}
}
multA := &projCached{}
multB := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

View File

@ -1,129 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]affineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]affineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
*dest = v.points[x/2]
}

View File

@ -1,42 +0,0 @@
// based on https://github.com/golang/go/blob/master/src/net/url/url.go
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
func shouldEscape(c byte) bool {
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c == '-' || c == '~' || c == '.' {
return false
}
return true
}
func escape(s string) string {
hexCount := 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c) {
hexCount++
}
}
if hexCount == 0 {
return s
}
t := make([]byte, len(s)+2*hexCount)
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case shouldEscape(c):
t[j] = '%'
t[j+1] = "0123456789ABCDEF"[c>>4]
t[j+2] = "0123456789ABCDEF"[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}

View File

@ -1,201 +0,0 @@
// HWS API Gateway Signature
// based on https://github.com/datastream/aws/blob/master/signv4.go
// Copyright (c) 2014, Xianjie
package core
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"time"
)
const (
DateFormat = "20060102T150405Z"
SignAlgorithm = "SDK-HMAC-SHA256"
HeaderXDateTime = "X-Sdk-Date"
HeaderXHost = "host"
HeaderXAuthorization = "Authorization"
HeaderXContentSha256 = "X-Sdk-Content-Sha256"
)
func hmacsha256(keyByte []byte, dataStr string) ([]byte, error) {
hm := hmac.New(sha256.New, []byte(keyByte))
if _, err := hm.Write([]byte(dataStr)); err != nil {
return nil, err
}
return hm.Sum(nil), nil
}
// Build a CanonicalRequest from a regular request string
func CanonicalRequest(request *http.Request, signedHeaders []string) (string, error) {
var hexencode string
var err error
if hex := request.Header.Get(HeaderXContentSha256); hex != "" {
hexencode = hex
} else {
bodyData, err := RequestPayload(request)
if err != nil {
return "", err
}
hexencode, err = HexEncodeSHA256Hash(bodyData)
if err != nil {
return "", err
}
}
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", request.Method, CanonicalURI(request), CanonicalQueryString(request), CanonicalHeaders(request, signedHeaders), strings.Join(signedHeaders, ";"), hexencode), err
}
// CanonicalURI returns request uri
func CanonicalURI(request *http.Request) string {
pattens := strings.Split(request.URL.Path, "/")
var uriSlice []string
for _, v := range pattens {
uriSlice = append(uriSlice, escape(v))
}
urlpath := strings.Join(uriSlice, "/")
if len(urlpath) == 0 || urlpath[len(urlpath)-1] != '/' {
urlpath = urlpath + "/"
}
return urlpath
}
// CanonicalQueryString
func CanonicalQueryString(request *http.Request) string {
var keys []string
queryMap := request.URL.Query()
for key := range queryMap {
keys = append(keys, key)
}
sort.Strings(keys)
var query []string
for _, key := range keys {
k := escape(key)
sort.Strings(queryMap[key])
for _, v := range queryMap[key] {
kv := fmt.Sprintf("%s=%s", k, escape(v))
query = append(query, kv)
}
}
queryStr := strings.Join(query, "&")
request.URL.RawQuery = queryStr
return queryStr
}
// CanonicalHeaders
func CanonicalHeaders(request *http.Request, signerHeaders []string) string {
var canonicalHeaders []string
header := make(map[string][]string)
for k, v := range request.Header {
header[strings.ToLower(k)] = v
}
for _, key := range signerHeaders {
value := header[key]
if strings.EqualFold(key, HeaderXHost) {
value = []string{request.Host}
}
sort.Strings(value)
for _, v := range value {
canonicalHeaders = append(canonicalHeaders, key+":"+strings.TrimSpace(v))
}
}
return fmt.Sprintf("%s\n", strings.Join(canonicalHeaders, "\n"))
}
// SignedHeaders
func SignedHeaders(r *http.Request) []string {
var signedHeaders []string
for key := range r.Header {
signedHeaders = append(signedHeaders, strings.ToLower(key))
}
sort.Strings(signedHeaders)
return signedHeaders
}
// RequestPayload
func RequestPayload(request *http.Request) ([]byte, error) {
if request.Body == nil {
return []byte(""), nil
}
bodyByte, err := ioutil.ReadAll(request.Body)
if err != nil {
return []byte(""), err
}
request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyByte))
return bodyByte, err
}
// Create a "String to Sign".
func StringToSign(canonicalRequest string, t time.Time) (string, error) {
hashStruct := sha256.New()
_, err := hashStruct.Write([]byte(canonicalRequest))
if err != nil {
return "", err
}
return fmt.Sprintf("%s\n%s\n%x",
SignAlgorithm, t.UTC().Format(DateFormat), hashStruct.Sum(nil)), nil
}
// Create the HWS Signature.
func SignStringToSign(stringToSign string, signingKey []byte) (string, error) {
hmsha, err := hmacsha256(signingKey, stringToSign)
return fmt.Sprintf("%x", hmsha), err
}
// HexEncodeSHA256Hash returns hexcode of sha256
func HexEncodeSHA256Hash(body []byte) (string, error) {
hashStruct := sha256.New()
if len(body) == 0 {
body = []byte("")
}
_, err := hashStruct.Write(body)
return fmt.Sprintf("%x", hashStruct.Sum(nil)), err
}
// Get the finalized value for the "Authorization" header. The signature parameter is the output from SignStringToSign
func AuthHeaderValue(signatureStr, accessKeyStr string, signedHeaders []string) string {
return fmt.Sprintf("%s Access=%s, SignedHeaders=%s, Signature=%s", SignAlgorithm, accessKeyStr, strings.Join(signedHeaders, ";"), signatureStr)
}
// Signature HWS meta
type Signer struct {
Key string
Secret string
}
// SignRequest set Authorization header
func (s *Signer) Sign(request *http.Request) error {
var t time.Time
var err error
var date string
if date = request.Header.Get(HeaderXDateTime); date != "" {
t, err = time.Parse(DateFormat, date)
}
if err != nil || date == "" {
t = time.Now()
request.Header.Set(HeaderXDateTime, t.UTC().Format(DateFormat))
}
signedHeaders := SignedHeaders(request)
canonicalRequest, err := CanonicalRequest(request, signedHeaders)
if err != nil {
return err
}
stringToSignStr, err := StringToSign(canonicalRequest, t)
if err != nil {
return err
}
signatureStr, err := SignStringToSign(stringToSignStr, []byte(s.Secret))
if err != nil {
return err
}
authValueStr := AuthHeaderValue(signatureStr, s.Key, signedHeaders)
request.Header.Set(HeaderXAuthorization, authValueStr)
return nil
}

View File

@ -1 +0,0 @@
squirrel.test

View File

@ -1,30 +0,0 @@
language: go
go:
- 1.11.x
- 1.12.x
- 1.13.x
services:
- mysql
- postgresql
# Setting sudo access to false will let Travis CI use containers rather than
# VMs to run the tests. For more details see:
# - http://docs.travis-ci.com/user/workers/container-based-infrastructure/
# - http://docs.travis-ci.com/user/workers/standard-infrastructure/
sudo: false
before_script:
- mysql -e 'CREATE DATABASE squirrel;'
- psql -c 'CREATE DATABASE squirrel;' -U postgres
script:
- go test
- cd integration
- go test -args -driver sqlite3
- go test -args -driver mysql -dataSource travis@/squirrel
- go test -args -driver postgres -dataSource 'postgres://postgres@localhost/squirrel?sslmode=disable'
notifications:
irc: "irc.freenode.net#masterminds"

View File

@ -1,23 +0,0 @@
MIT License
Squirrel: The Masterminds
Copyright (c) 2014-2015, Lann Martin. Copyright (C) 2015-2016, Google. Copyright (C) 2015, Matt Farina and Matt Butcher.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,142 +0,0 @@
[![Stability: Maintenance](https://masterminds.github.io/stability/maintenance.svg)](https://masterminds.github.io/stability/maintenance.html)
### Squirrel is "complete".
Bug fixes will still be merged (slowly). Bug reports are welcome, but I will not necessarily respond to them. If another fork (or substantially similar project) actively improves on what Squirrel does, let me know and I may link to it here.
# Squirrel - fluent SQL generator for Go
```go
import "github.com/Masterminds/squirrel"
```
[![GoDoc](https://godoc.org/github.com/Masterminds/squirrel?status.png)](https://godoc.org/github.com/Masterminds/squirrel)
[![Build Status](https://api.travis-ci.org/Masterminds/squirrel.svg?branch=master)](https://travis-ci.org/Masterminds/squirrel)
**Squirrel is not an ORM.** For an application of Squirrel, check out
[structable, a table-struct mapper](https://github.com/Masterminds/structable)
Squirrel helps you build SQL queries from composable parts:
```go
import sq "github.com/Masterminds/squirrel"
users := sq.Select("*").From("users").Join("emails USING (email_id)")
active := users.Where(sq.Eq{"deleted_at": nil})
sql, args, err := active.ToSql()
sql == "SELECT * FROM users JOIN emails USING (email_id) WHERE deleted_at IS NULL"
```
```go
sql, args, err := sq.
Insert("users").Columns("name", "age").
Values("moe", 13).Values("larry", sq.Expr("? + 5", 12)).
ToSql()
sql == "INSERT INTO users (name,age) VALUES (?,?),(?,? + 5)"
```
Squirrel can also execute queries directly:
```go
stooges := users.Where(sq.Eq{"username": []string{"moe", "larry", "curly", "shemp"}})
three_stooges := stooges.Limit(3)
rows, err := three_stooges.RunWith(db).Query()
// Behaves like:
rows, err := db.Query("SELECT * FROM users WHERE username IN (?,?,?,?) LIMIT 3",
"moe", "larry", "curly", "shemp")
```
Squirrel makes conditional query building a breeze:
```go
if len(q) > 0 {
users = users.Where("name LIKE ?", fmt.Sprint("%", q, "%"))
}
```
Squirrel wants to make your life easier:
```go
// StmtCache caches Prepared Stmts for you
dbCache := sq.NewStmtCache(db)
// StatementBuilder keeps your syntax neat
mydb := sq.StatementBuilder.RunWith(dbCache)
select_users := mydb.Select("*").From("users")
```
Squirrel loves PostgreSQL:
```go
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
// You use question marks for placeholders...
sql, _, _ := psql.Select("*").From("elephants").Where("name IN (?,?)", "Dumbo", "Verna").ToSql()
/// ...squirrel replaces them using PlaceholderFormat.
sql == "SELECT * FROM elephants WHERE name IN ($1,$2)"
/// You can retrieve id ...
query := sq.Insert("nodes").
Columns("uuid", "type", "data").
Values(node.Uuid, node.Type, node.Data).
Suffix("RETURNING \"id\"").
RunWith(m.db).
PlaceholderFormat(sq.Dollar)
query.QueryRow().Scan(&node.id)
```
You can escape question marks by inserting two question marks:
```sql
SELECT * FROM nodes WHERE meta->'format' ??| array[?,?]
```
will generate with the Dollar Placeholder:
```sql
SELECT * FROM nodes WHERE meta->'format' ?| array[$1,$2]
```
## FAQ
* **How can I build an IN query on composite keys / tuples, e.g. `WHERE (col1, col2) IN ((1,2),(3,4))`? ([#104](https://github.com/Masterminds/squirrel/issues/104))**
Squirrel does not explicitly support tuples, but you can get the same effect with e.g.:
```go
sq.Or{
sq.Eq{"col1": 1, "col2": 2},
sq.Eq{"col1": 3, "col2": 4}}
```
```sql
WHERE (col1 = 1 AND col2 = 2) OR (col1 = 3 AND col2 = 4)
```
(which should produce the same query plan as the tuple version)
* **Why doesn't `Eq{"mynumber": []uint8{1,2,3}}` turn into an `IN` query? ([#114](https://github.com/Masterminds/squirrel/issues/114))**
Values of type `[]byte` are handled specially by `database/sql`. In Go, [`byte` is just an alias of `uint8`](https://golang.org/pkg/builtin/#byte), so there is no way to distinguish `[]uint8` from `[]byte`.
* **Some features are poorly documented!**
This isn't a frequent complaints section!
* **Some features are poorly documented?**
Yes. The tests should be considered a part of the documentation; take a look at those for ideas on how to express more complex queries.
## License
Squirrel is released under the
[MIT License](http://www.opensource.org/licenses/MIT).

View File

@ -1,128 +0,0 @@
package squirrel
import (
"bytes"
"errors"
"github.com/lann/builder"
)
func init() {
builder.Register(CaseBuilder{}, caseData{})
}
// sqlizerBuffer is a helper that allows to write many Sqlizers one by one
// without constant checks for errors that may come from Sqlizer
type sqlizerBuffer struct {
bytes.Buffer
args []interface{}
err error
}
// WriteSql converts Sqlizer to SQL strings and writes it to buffer
func (b *sqlizerBuffer) WriteSql(item Sqlizer) {
if b.err != nil {
return
}
var str string
var args []interface{}
str, args, b.err = nestedToSql(item)
if b.err != nil {
return
}
b.WriteString(str)
b.WriteByte(' ')
b.args = append(b.args, args...)
}
func (b *sqlizerBuffer) ToSql() (string, []interface{}, error) {
return b.String(), b.args, b.err
}
// whenPart is a helper structure to describe SQLs "WHEN ... THEN ..." expression
type whenPart struct {
when Sqlizer
then Sqlizer
}
func newWhenPart(when interface{}, then interface{}) whenPart {
return whenPart{newPart(when), newPart(then)}
}
// caseData holds all the data required to build a CASE SQL construct
type caseData struct {
What Sqlizer
WhenParts []whenPart
Else Sqlizer
}
// ToSql implements Sqlizer
func (d *caseData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.WhenParts) == 0 {
err = errors.New("case expression must contain at lease one WHEN clause")
return
}
sql := sqlizerBuffer{}
sql.WriteString("CASE ")
if d.What != nil {
sql.WriteSql(d.What)
}
for _, p := range d.WhenParts {
sql.WriteString("WHEN ")
sql.WriteSql(p.when)
sql.WriteString("THEN ")
sql.WriteSql(p.then)
}
if d.Else != nil {
sql.WriteString("ELSE ")
sql.WriteSql(d.Else)
}
sql.WriteString("END")
return sql.ToSql()
}
// CaseBuilder builds SQL CASE construct which could be used as parts of queries.
type CaseBuilder builder.Builder
// ToSql builds the query into a SQL string and bound args.
func (b CaseBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(caseData)
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b CaseBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// what sets optional value for CASE construct "CASE [value] ..."
func (b CaseBuilder) what(expr interface{}) CaseBuilder {
return builder.Set(b, "What", newPart(expr)).(CaseBuilder)
}
// When adds "WHEN ... THEN ..." part to CASE construct
func (b CaseBuilder) When(when interface{}, then interface{}) CaseBuilder {
// TODO: performance hint: replace slice of WhenPart with just slice of parts
// where even indices of the slice belong to "when"s and odd indices belong to "then"s
return builder.Append(b, "WhenParts", newWhenPart(when, then)).(CaseBuilder)
}
// What sets optional "ELSE ..." part for CASE construct
func (b CaseBuilder) Else(expr interface{}) CaseBuilder {
return builder.Set(b, "Else", newPart(expr)).(CaseBuilder)
}

View File

@ -1,191 +0,0 @@
package squirrel
import (
"bytes"
"database/sql"
"fmt"
"strings"
"github.com/lann/builder"
)
type deleteData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes []Sqlizer
From string
WhereParts []Sqlizer
OrderBys []string
Limit string
Offset string
Suffixes []Sqlizer
}
func (d *deleteData) Exec() (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return ExecWith(d.RunWith, d)
}
func (d *deleteData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.From) == 0 {
err = fmt.Errorf("delete statements must specify a From table")
return
}
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
sql.WriteString("DELETE FROM ")
sql.WriteString(d.From)
if len(d.WhereParts) > 0 {
sql.WriteString(" WHERE ")
args, err = appendToSql(d.WhereParts, sql, " AND ", args)
if err != nil {
return
}
}
if len(d.OrderBys) > 0 {
sql.WriteString(" ORDER BY ")
sql.WriteString(strings.Join(d.OrderBys, ", "))
}
if len(d.Limit) > 0 {
sql.WriteString(" LIMIT ")
sql.WriteString(d.Limit)
}
if len(d.Offset) > 0 {
sql.WriteString(" OFFSET ")
sql.WriteString(d.Offset)
}
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
return
}
// Builder
// DeleteBuilder builds SQL DELETE statements.
type DeleteBuilder builder.Builder
func init() {
builder.Register(DeleteBuilder{}, deleteData{})
}
// Format methods
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b DeleteBuilder) PlaceholderFormat(f PlaceholderFormat) DeleteBuilder {
return builder.Set(b, "PlaceholderFormat", f).(DeleteBuilder)
}
// Runner methods
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
func (b DeleteBuilder) RunWith(runner BaseRunner) DeleteBuilder {
return setRunWith(b, runner).(DeleteBuilder)
}
// Exec builds and Execs the query with the Runner set by RunWith.
func (b DeleteBuilder) Exec() (sql.Result, error) {
data := builder.GetStruct(b).(deleteData)
return data.Exec()
}
// SQL methods
// ToSql builds the query into a SQL string and bound args.
func (b DeleteBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(deleteData)
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b DeleteBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b DeleteBuilder) Prefix(sql string, args ...interface{}) DeleteBuilder {
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b DeleteBuilder) PrefixExpr(expr Sqlizer) DeleteBuilder {
return builder.Append(b, "Prefixes", expr).(DeleteBuilder)
}
// From sets the table to be deleted from.
func (b DeleteBuilder) From(from string) DeleteBuilder {
return builder.Set(b, "From", from).(DeleteBuilder)
}
// Where adds WHERE expressions to the query.
//
// See SelectBuilder.Where for more information.
func (b DeleteBuilder) Where(pred interface{}, args ...interface{}) DeleteBuilder {
return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(DeleteBuilder)
}
// OrderBy adds ORDER BY expressions to the query.
func (b DeleteBuilder) OrderBy(orderBys ...string) DeleteBuilder {
return builder.Extend(b, "OrderBys", orderBys).(DeleteBuilder)
}
// Limit sets a LIMIT clause on the query.
func (b DeleteBuilder) Limit(limit uint64) DeleteBuilder {
return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(DeleteBuilder)
}
// Offset sets a OFFSET clause on the query.
func (b DeleteBuilder) Offset(offset uint64) DeleteBuilder {
return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(DeleteBuilder)
}
// Suffix adds an expression to the end of the query
func (b DeleteBuilder) Suffix(sql string, args ...interface{}) DeleteBuilder {
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b DeleteBuilder) SuffixExpr(expr Sqlizer) DeleteBuilder {
return builder.Append(b, "Suffixes", expr).(DeleteBuilder)
}
func (b DeleteBuilder) Query() (*sql.Rows, error) {
data := builder.GetStruct(b).(deleteData)
return data.Query()
}
func (d *deleteData) Query() (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return QueryWith(d.RunWith, d)
}

View File

@ -1,69 +0,0 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *deleteData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *deleteData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *deleteData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b DeleteBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(deleteData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b DeleteBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(deleteData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b DeleteBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(deleteData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b DeleteBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@ -1,419 +0,0 @@
package squirrel
import (
"bytes"
"database/sql/driver"
"fmt"
"reflect"
"sort"
"strings"
)
const (
// Portable true/false literals.
sqlTrue = "(1=1)"
sqlFalse = "(1=0)"
)
type expr struct {
sql string
args []interface{}
}
// Expr builds an expression from a SQL fragment and arguments.
//
// Ex:
// Expr("FROM_UNIXTIME(?)", t)
func Expr(sql string, args ...interface{}) Sqlizer {
return expr{sql: sql, args: args}
}
func (e expr) ToSql() (sql string, args []interface{}, err error) {
simple := true
for _, arg := range e.args {
if _, ok := arg.(Sqlizer); ok {
simple = false
}
}
if simple {
return e.sql, e.args, nil
}
buf := &bytes.Buffer{}
ap := e.args
sp := e.sql
var isql string
var iargs []interface{}
for err == nil && len(ap) > 0 && len(sp) > 0 {
i := strings.Index(sp, "?")
if i < 0 {
// no more placeholders
break
}
if len(sp) > i+1 && sp[i+1:i+2] == "?" {
// escaped "??"; append it and step past
buf.WriteString(sp[:i+2])
sp = sp[i+2:]
continue
}
if as, ok := ap[0].(Sqlizer); ok {
// sqlizer argument; expand it and append the result
isql, iargs, err = as.ToSql()
buf.WriteString(sp[:i])
buf.WriteString(isql)
args = append(args, iargs...)
} else {
// normal argument; append it and the placeholder
buf.WriteString(sp[:i+1])
args = append(args, ap[0])
}
// step past the argument and placeholder
ap = ap[1:]
sp = sp[i+1:]
}
// append the remaining sql and arguments
buf.WriteString(sp)
return buf.String(), append(args, ap...), err
}
type concatExpr []interface{}
func (ce concatExpr) ToSql() (sql string, args []interface{}, err error) {
for _, part := range ce {
switch p := part.(type) {
case string:
sql += p
case Sqlizer:
pSql, pArgs, err := p.ToSql()
if err != nil {
return "", nil, err
}
sql += pSql
args = append(args, pArgs...)
default:
return "", nil, fmt.Errorf("%#v is not a string or Sqlizer", part)
}
}
return
}
// ConcatExpr builds an expression by concatenating strings and other expressions.
//
// Ex:
// name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName)
// ConcatExpr("COALESCE(full_name,", name_expr, ")")
func ConcatExpr(parts ...interface{}) concatExpr {
return concatExpr(parts)
}
// aliasExpr helps to alias part of SQL query generated with underlying "expr"
type aliasExpr struct {
expr Sqlizer
alias string
}
// Alias allows to define alias for column in SelectBuilder. Useful when column is
// defined as complex expression like IF or CASE
// Ex:
// .Column(Alias(caseStmt, "case_column"))
func Alias(expr Sqlizer, alias string) aliasExpr {
return aliasExpr{expr, alias}
}
func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) {
sql, args, err = e.expr.ToSql()
if err == nil {
sql = fmt.Sprintf("(%s) AS %s", sql, e.alias)
}
return
}
// Eq is syntactic sugar for use with Where/Having/Set methods.
type Eq map[string]interface{}
func (eq Eq) toSQL(useNotOpr bool) (sql string, args []interface{}, err error) {
if len(eq) == 0 {
// Empty Sql{} evaluates to true.
sql = sqlTrue
return
}
var (
exprs []string
equalOpr = "="
inOpr = "IN"
nullOpr = "IS"
inEmptyExpr = sqlFalse
)
if useNotOpr {
equalOpr = "<>"
inOpr = "NOT IN"
nullOpr = "IS NOT"
inEmptyExpr = sqlTrue
}
sortedKeys := getSortedKeys(eq)
for _, key := range sortedKeys {
var expr string
val := eq[key]
switch v := val.(type) {
case driver.Valuer:
if val, err = v.Value(); err != nil {
return
}
}
r := reflect.ValueOf(val)
if r.Kind() == reflect.Ptr {
if r.IsNil() {
val = nil
} else {
val = r.Elem().Interface()
}
}
if val == nil {
expr = fmt.Sprintf("%s %s NULL", key, nullOpr)
} else {
if isListType(val) {
valVal := reflect.ValueOf(val)
if valVal.Len() == 0 {
expr = inEmptyExpr
if args == nil {
args = []interface{}{}
}
} else {
for i := 0; i < valVal.Len(); i++ {
args = append(args, valVal.Index(i).Interface())
}
expr = fmt.Sprintf("%s %s (%s)", key, inOpr, Placeholders(valVal.Len()))
}
} else {
expr = fmt.Sprintf("%s %s ?", key, equalOpr)
args = append(args, val)
}
}
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
return
}
func (eq Eq) ToSql() (sql string, args []interface{}, err error) {
return eq.toSQL(false)
}
// NotEq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(NotEq{"id": 1}) == "id <> 1"
type NotEq Eq
func (neq NotEq) ToSql() (sql string, args []interface{}, err error) {
return Eq(neq).toSQL(true)
}
// Like is syntactic sugar for use with LIKE conditions.
// Ex:
// .Where(Like{"name": "%irrel"})
type Like map[string]interface{}
func (lk Like) toSql(opr string) (sql string, args []interface{}, err error) {
var exprs []string
for key, val := range lk {
expr := ""
switch v := val.(type) {
case driver.Valuer:
if val, err = v.Value(); err != nil {
return
}
}
if val == nil {
err = fmt.Errorf("cannot use null with like operators")
return
} else {
if isListType(val) {
err = fmt.Errorf("cannot use array or slice with like operators")
return
} else {
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
}
}
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
return
}
func (lk Like) ToSql() (sql string, args []interface{}, err error) {
return lk.toSql("LIKE")
}
// NotLike is syntactic sugar for use with LIKE conditions.
// Ex:
// .Where(NotLike{"name": "%irrel"})
type NotLike Like
func (nlk NotLike) ToSql() (sql string, args []interface{}, err error) {
return Like(nlk).toSql("NOT LIKE")
}
// ILike is syntactic sugar for use with ILIKE conditions.
// Ex:
// .Where(ILike{"name": "sq%"})
type ILike Like
func (ilk ILike) ToSql() (sql string, args []interface{}, err error) {
return Like(ilk).toSql("ILIKE")
}
// NotILike is syntactic sugar for use with ILIKE conditions.
// Ex:
// .Where(NotILike{"name": "sq%"})
type NotILike Like
func (nilk NotILike) ToSql() (sql string, args []interface{}, err error) {
return Like(nilk).toSql("NOT ILIKE")
}
// Lt is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(Lt{"id": 1})
type Lt map[string]interface{}
func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) {
var (
exprs []string
opr = "<"
)
if opposite {
opr = ">"
}
if orEq {
opr = fmt.Sprintf("%s%s", opr, "=")
}
sortedKeys := getSortedKeys(lt)
for _, key := range sortedKeys {
var expr string
val := lt[key]
switch v := val.(type) {
case driver.Valuer:
if val, err = v.Value(); err != nil {
return
}
}
if val == nil {
err = fmt.Errorf("cannot use null with less than or greater than operators")
return
}
if isListType(val) {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
}
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
return
}
func (lt Lt) ToSql() (sql string, args []interface{}, err error) {
return lt.toSql(false, false)
}
// LtOrEq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(LtOrEq{"id": 1}) == "id <= 1"
type LtOrEq Lt
func (ltOrEq LtOrEq) ToSql() (sql string, args []interface{}, err error) {
return Lt(ltOrEq).toSql(false, true)
}
// Gt is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(Gt{"id": 1}) == "id > 1"
type Gt Lt
func (gt Gt) ToSql() (sql string, args []interface{}, err error) {
return Lt(gt).toSql(true, false)
}
// GtOrEq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(GtOrEq{"id": 1}) == "id >= 1"
type GtOrEq Lt
func (gtOrEq GtOrEq) ToSql() (sql string, args []interface{}, err error) {
return Lt(gtOrEq).toSql(true, true)
}
type conj []Sqlizer
func (c conj) join(sep, defaultExpr string) (sql string, args []interface{}, err error) {
if len(c) == 0 {
return defaultExpr, []interface{}{}, nil
}
var sqlParts []string
for _, sqlizer := range c {
partSQL, partArgs, err := nestedToSql(sqlizer)
if err != nil {
return "", nil, err
}
if partSQL != "" {
sqlParts = append(sqlParts, partSQL)
args = append(args, partArgs...)
}
}
if len(sqlParts) > 0 {
sql = fmt.Sprintf("(%s)", strings.Join(sqlParts, sep))
}
return
}
// And conjunction Sqlizers
type And conj
func (a And) ToSql() (string, []interface{}, error) {
return conj(a).join(" AND ", sqlTrue)
}
// Or conjunction Sqlizers
type Or conj
func (o Or) ToSql() (string, []interface{}, error) {
return conj(o).join(" OR ", sqlFalse)
}
func getSortedKeys(exp map[string]interface{}) []string {
sortedKeys := make([]string, 0, len(exp))
for k := range exp {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
return sortedKeys
}
func isListType(val interface{}) bool {
if driver.IsValue(val) {
return false
}
valVal := reflect.ValueOf(val)
return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice
}

View File

@ -1,298 +0,0 @@
package squirrel
import (
"bytes"
"database/sql"
"errors"
"fmt"
"io"
"sort"
"strings"
"github.com/lann/builder"
)
type insertData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes []Sqlizer
StatementKeyword string
Options []string
Into string
Columns []string
Values [][]interface{}
Suffixes []Sqlizer
Select *SelectBuilder
}
func (d *insertData) Exec() (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return ExecWith(d.RunWith, d)
}
func (d *insertData) Query() (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return QueryWith(d.RunWith, d)
}
func (d *insertData) QueryRow() RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRower)
if !ok {
return &Row{err: RunnerNotQueryRunner}
}
return QueryRowWith(queryRower, d)
}
func (d *insertData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.Into) == 0 {
err = errors.New("insert statements must specify a table")
return
}
if len(d.Values) == 0 && d.Select == nil {
err = errors.New("insert statements must have at least one set of values or select clause")
return
}
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
if d.StatementKeyword == "" {
sql.WriteString("INSERT ")
} else {
sql.WriteString(d.StatementKeyword)
sql.WriteString(" ")
}
if len(d.Options) > 0 {
sql.WriteString(strings.Join(d.Options, " "))
sql.WriteString(" ")
}
sql.WriteString("INTO ")
sql.WriteString(d.Into)
sql.WriteString(" ")
if len(d.Columns) > 0 {
sql.WriteString("(")
sql.WriteString(strings.Join(d.Columns, ","))
sql.WriteString(") ")
}
if d.Select != nil {
args, err = d.appendSelectToSQL(sql, args)
} else {
args, err = d.appendValuesToSQL(sql, args)
}
if err != nil {
return
}
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
return
}
func (d *insertData) appendValuesToSQL(w io.Writer, args []interface{}) ([]interface{}, error) {
if len(d.Values) == 0 {
return args, errors.New("values for insert statements are not set")
}
io.WriteString(w, "VALUES ")
valuesStrings := make([]string, len(d.Values))
for r, row := range d.Values {
valueStrings := make([]string, len(row))
for v, val := range row {
if vs, ok := val.(Sqlizer); ok {
vsql, vargs, err := vs.ToSql()
if err != nil {
return nil, err
}
valueStrings[v] = vsql
args = append(args, vargs...)
} else {
valueStrings[v] = "?"
args = append(args, val)
}
}
valuesStrings[r] = fmt.Sprintf("(%s)", strings.Join(valueStrings, ","))
}
io.WriteString(w, strings.Join(valuesStrings, ","))
return args, nil
}
func (d *insertData) appendSelectToSQL(w io.Writer, args []interface{}) ([]interface{}, error) {
if d.Select == nil {
return args, errors.New("select clause for insert statements are not set")
}
selectClause, sArgs, err := d.Select.ToSql()
if err != nil {
return args, err
}
io.WriteString(w, selectClause)
args = append(args, sArgs...)
return args, nil
}
// Builder
// InsertBuilder builds SQL INSERT statements.
type InsertBuilder builder.Builder
func init() {
builder.Register(InsertBuilder{}, insertData{})
}
// Format methods
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b InsertBuilder) PlaceholderFormat(f PlaceholderFormat) InsertBuilder {
return builder.Set(b, "PlaceholderFormat", f).(InsertBuilder)
}
// Runner methods
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
func (b InsertBuilder) RunWith(runner BaseRunner) InsertBuilder {
return setRunWith(b, runner).(InsertBuilder)
}
// Exec builds and Execs the query with the Runner set by RunWith.
func (b InsertBuilder) Exec() (sql.Result, error) {
data := builder.GetStruct(b).(insertData)
return data.Exec()
}
// Query builds and Querys the query with the Runner set by RunWith.
func (b InsertBuilder) Query() (*sql.Rows, error) {
data := builder.GetStruct(b).(insertData)
return data.Query()
}
// QueryRow builds and QueryRows the query with the Runner set by RunWith.
func (b InsertBuilder) QueryRow() RowScanner {
data := builder.GetStruct(b).(insertData)
return data.QueryRow()
}
// Scan is a shortcut for QueryRow().Scan.
func (b InsertBuilder) Scan(dest ...interface{}) error {
return b.QueryRow().Scan(dest...)
}
// SQL methods
// ToSql builds the query into a SQL string and bound args.
func (b InsertBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(insertData)
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b InsertBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b InsertBuilder) Prefix(sql string, args ...interface{}) InsertBuilder {
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b InsertBuilder) PrefixExpr(expr Sqlizer) InsertBuilder {
return builder.Append(b, "Prefixes", expr).(InsertBuilder)
}
// Options adds keyword options before the INTO clause of the query.
func (b InsertBuilder) Options(options ...string) InsertBuilder {
return builder.Extend(b, "Options", options).(InsertBuilder)
}
// Into sets the INTO clause of the query.
func (b InsertBuilder) Into(from string) InsertBuilder {
return builder.Set(b, "Into", from).(InsertBuilder)
}
// Columns adds insert columns to the query.
func (b InsertBuilder) Columns(columns ...string) InsertBuilder {
return builder.Extend(b, "Columns", columns).(InsertBuilder)
}
// Values adds a single row's values to the query.
func (b InsertBuilder) Values(values ...interface{}) InsertBuilder {
return builder.Append(b, "Values", values).(InsertBuilder)
}
// Suffix adds an expression to the end of the query
func (b InsertBuilder) Suffix(sql string, args ...interface{}) InsertBuilder {
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b InsertBuilder) SuffixExpr(expr Sqlizer) InsertBuilder {
return builder.Append(b, "Suffixes", expr).(InsertBuilder)
}
// SetMap set columns and values for insert builder from a map of column name and value
// note that it will reset all previous columns and values was set if any
func (b InsertBuilder) SetMap(clauses map[string]interface{}) InsertBuilder {
// Keep the columns in a consistent order by sorting the column key string.
cols := make([]string, 0, len(clauses))
for col := range clauses {
cols = append(cols, col)
}
sort.Strings(cols)
vals := make([]interface{}, 0, len(clauses))
for _, col := range cols {
vals = append(vals, clauses[col])
}
b = builder.Set(b, "Columns", cols).(InsertBuilder)
b = builder.Set(b, "Values", [][]interface{}{vals}).(InsertBuilder)
return b
}
// Select set Select clause for insert query
// If Values and Select are used, then Select has higher priority
func (b InsertBuilder) Select(sb SelectBuilder) InsertBuilder {
return builder.Set(b, "Select", &sb).(InsertBuilder)
}
func (b InsertBuilder) statementKeyword(keyword string) InsertBuilder {
return builder.Set(b, "StatementKeyword", keyword).(InsertBuilder)
}

View File

@ -1,69 +0,0 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *insertData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *insertData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *insertData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b InsertBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(insertData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b InsertBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(insertData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b InsertBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(insertData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b InsertBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@ -1,63 +0,0 @@
package squirrel
import (
"fmt"
"io"
)
type part struct {
pred interface{}
args []interface{}
}
func newPart(pred interface{}, args ...interface{}) Sqlizer {
return &part{pred, args}
}
func (p part) ToSql() (sql string, args []interface{}, err error) {
switch pred := p.pred.(type) {
case nil:
// no-op
case Sqlizer:
sql, args, err = nestedToSql(pred)
case string:
sql = pred
args = p.args
default:
err = fmt.Errorf("expected string or Sqlizer, not %T", pred)
}
return
}
func nestedToSql(s Sqlizer) (string, []interface{}, error) {
if raw, ok := s.(rawSqlizer); ok {
return raw.toSqlRaw()
} else {
return s.ToSql()
}
}
func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []interface{}) ([]interface{}, error) {
for i, p := range parts {
partSql, partArgs, err := nestedToSql(p)
if err != nil {
return nil, err
} else if len(partSql) == 0 {
continue
}
if i > 0 {
_, err := io.WriteString(w, sep)
if err != nil {
return nil, err
}
}
_, err = io.WriteString(w, partSql)
if err != nil {
return nil, err
}
args = append(args, partArgs...)
}
return args, nil
}

View File

@ -1,114 +0,0 @@
package squirrel
import (
"bytes"
"fmt"
"strings"
)
// PlaceholderFormat is the interface that wraps the ReplacePlaceholders method.
//
// ReplacePlaceholders takes a SQL statement and replaces each question mark
// placeholder with a (possibly different) SQL placeholder.
type PlaceholderFormat interface {
ReplacePlaceholders(sql string) (string, error)
}
type placeholderDebugger interface {
debugPlaceholder() string
}
var (
// Question is a PlaceholderFormat instance that leaves placeholders as
// question marks.
Question = questionFormat{}
// Dollar is a PlaceholderFormat instance that replaces placeholders with
// dollar-prefixed positional placeholders (e.g. $1, $2, $3).
Dollar = dollarFormat{}
// Colon is a PlaceholderFormat instance that replaces placeholders with
// colon-prefixed positional placeholders (e.g. :1, :2, :3).
Colon = colonFormat{}
// AtP is a PlaceholderFormat instance that replaces placeholders with
// "@p"-prefixed positional placeholders (e.g. @p1, @p2, @p3).
AtP = atpFormat{}
)
type questionFormat struct{}
func (questionFormat) ReplacePlaceholders(sql string) (string, error) {
return sql, nil
}
func (questionFormat) debugPlaceholder() string {
return "?"
}
type dollarFormat struct{}
func (dollarFormat) ReplacePlaceholders(sql string) (string, error) {
return replacePositionalPlaceholders(sql, "$")
}
func (dollarFormat) debugPlaceholder() string {
return "$"
}
type colonFormat struct{}
func (colonFormat) ReplacePlaceholders(sql string) (string, error) {
return replacePositionalPlaceholders(sql, ":")
}
func (colonFormat) debugPlaceholder() string {
return ":"
}
type atpFormat struct{}
func (atpFormat) ReplacePlaceholders(sql string) (string, error) {
return replacePositionalPlaceholders(sql, "@p")
}
func (atpFormat) debugPlaceholder() string {
return "@p"
}
// Placeholders returns a string with count ? placeholders joined with commas.
func Placeholders(count int) string {
if count < 1 {
return ""
}
return strings.Repeat(",?", count)[1:]
}
func replacePositionalPlaceholders(sql, prefix string) (string, error) {
buf := &bytes.Buffer{}
i := 0
for {
p := strings.Index(sql, "?")
if p == -1 {
break
}
if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ?
buf.WriteString(sql[:p])
buf.WriteString("?")
if len(sql[p:]) == 1 {
break
}
sql = sql[p+2:]
} else {
i++
buf.WriteString(sql[:p])
fmt.Fprintf(buf, "%s%d", prefix, i)
sql = sql[p+1:]
}
}
buf.WriteString(sql)
return buf.String(), nil
}

View File

@ -1,22 +0,0 @@
package squirrel
// RowScanner is the interface that wraps the Scan method.
//
// Scan behaves like database/sql.Row.Scan.
type RowScanner interface {
Scan(...interface{}) error
}
// Row wraps database/sql.Row to let squirrel return new errors on Scan.
type Row struct {
RowScanner
err error
}
// Scan returns Row.err or calls RowScanner.Scan.
func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil {
return r.err
}
return r.RowScanner.Scan(dest...)
}

View File

@ -1,403 +0,0 @@
package squirrel
import (
"bytes"
"database/sql"
"fmt"
"strings"
"github.com/lann/builder"
)
type selectData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes []Sqlizer
Options []string
Columns []Sqlizer
From Sqlizer
Joins []Sqlizer
WhereParts []Sqlizer
GroupBys []string
HavingParts []Sqlizer
OrderByParts []Sqlizer
Limit string
Offset string
Suffixes []Sqlizer
}
func (d *selectData) Exec() (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return ExecWith(d.RunWith, d)
}
func (d *selectData) Query() (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return QueryWith(d.RunWith, d)
}
func (d *selectData) QueryRow() RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRower)
if !ok {
return &Row{err: RunnerNotQueryRunner}
}
return QueryRowWith(queryRower, d)
}
func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
sqlStr, args, err = d.toSqlRaw()
if err != nil {
return
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sqlStr)
return
}
func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
if len(d.Columns) == 0 {
err = fmt.Errorf("select statements must have at least one result column")
return
}
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
sql.WriteString("SELECT ")
if len(d.Options) > 0 {
sql.WriteString(strings.Join(d.Options, " "))
sql.WriteString(" ")
}
if len(d.Columns) > 0 {
args, err = appendToSql(d.Columns, sql, ", ", args)
if err != nil {
return
}
}
if d.From != nil {
sql.WriteString(" FROM ")
args, err = appendToSql([]Sqlizer{d.From}, sql, "", args)
if err != nil {
return
}
}
if len(d.Joins) > 0 {
sql.WriteString(" ")
args, err = appendToSql(d.Joins, sql, " ", args)
if err != nil {
return
}
}
if len(d.WhereParts) > 0 {
sql.WriteString(" WHERE ")
args, err = appendToSql(d.WhereParts, sql, " AND ", args)
if err != nil {
return
}
}
if len(d.GroupBys) > 0 {
sql.WriteString(" GROUP BY ")
sql.WriteString(strings.Join(d.GroupBys, ", "))
}
if len(d.HavingParts) > 0 {
sql.WriteString(" HAVING ")
args, err = appendToSql(d.HavingParts, sql, " AND ", args)
if err != nil {
return
}
}
if len(d.OrderByParts) > 0 {
sql.WriteString(" ORDER BY ")
args, err = appendToSql(d.OrderByParts, sql, ", ", args)
if err != nil {
return
}
}
if len(d.Limit) > 0 {
sql.WriteString(" LIMIT ")
sql.WriteString(d.Limit)
}
if len(d.Offset) > 0 {
sql.WriteString(" OFFSET ")
sql.WriteString(d.Offset)
}
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr = sql.String()
return
}
// Builder
// SelectBuilder builds SQL SELECT statements.
type SelectBuilder builder.Builder
func init() {
builder.Register(SelectBuilder{}, selectData{})
}
// Format methods
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b SelectBuilder) PlaceholderFormat(f PlaceholderFormat) SelectBuilder {
return builder.Set(b, "PlaceholderFormat", f).(SelectBuilder)
}
// Runner methods
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
// For most cases runner will be a database connection.
//
// Internally we use this to mock out the database connection for testing.
func (b SelectBuilder) RunWith(runner BaseRunner) SelectBuilder {
return setRunWith(b, runner).(SelectBuilder)
}
// Exec builds and Execs the query with the Runner set by RunWith.
func (b SelectBuilder) Exec() (sql.Result, error) {
data := builder.GetStruct(b).(selectData)
return data.Exec()
}
// Query builds and Querys the query with the Runner set by RunWith.
func (b SelectBuilder) Query() (*sql.Rows, error) {
data := builder.GetStruct(b).(selectData)
return data.Query()
}
// QueryRow builds and QueryRows the query with the Runner set by RunWith.
func (b SelectBuilder) QueryRow() RowScanner {
data := builder.GetStruct(b).(selectData)
return data.QueryRow()
}
// Scan is a shortcut for QueryRow().Scan.
func (b SelectBuilder) Scan(dest ...interface{}) error {
return b.QueryRow().Scan(dest...)
}
// SQL methods
// ToSql builds the query into a SQL string and bound args.
func (b SelectBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(selectData)
return data.ToSql()
}
func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) {
data := builder.GetStruct(b).(selectData)
return data.toSqlRaw()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b SelectBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b SelectBuilder) Prefix(sql string, args ...interface{}) SelectBuilder {
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b SelectBuilder) PrefixExpr(expr Sqlizer) SelectBuilder {
return builder.Append(b, "Prefixes", expr).(SelectBuilder)
}
// Distinct adds a DISTINCT clause to the query.
func (b SelectBuilder) Distinct() SelectBuilder {
return b.Options("DISTINCT")
}
// Options adds select option to the query
func (b SelectBuilder) Options(options ...string) SelectBuilder {
return builder.Extend(b, "Options", options).(SelectBuilder)
}
// Columns adds result columns to the query.
func (b SelectBuilder) Columns(columns ...string) SelectBuilder {
parts := make([]interface{}, 0, len(columns))
for _, str := range columns {
parts = append(parts, newPart(str))
}
return builder.Extend(b, "Columns", parts).(SelectBuilder)
}
// RemoveColumns remove all columns from query.
// Must add a new column with Column or Columns methods, otherwise
// return a error.
func (b SelectBuilder) RemoveColumns() SelectBuilder {
return builder.Delete(b, "Columns").(SelectBuilder)
}
// Column adds a result column to the query.
// Unlike Columns, Column accepts args which will be bound to placeholders in
// the columns string, for example:
// Column("IF(col IN ("+squirrel.Placeholders(3)+"), 1, 0) as col", 1, 2, 3)
func (b SelectBuilder) Column(column interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "Columns", newPart(column, args...)).(SelectBuilder)
}
// From sets the FROM clause of the query.
func (b SelectBuilder) From(from string) SelectBuilder {
return builder.Set(b, "From", newPart(from)).(SelectBuilder)
}
// FromSelect sets a subquery into the FROM clause of the query.
func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilder {
// Prevent misnumbered parameters in nested selects (#183).
from = from.PlaceholderFormat(Question)
return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder)
}
// JoinClause adds a join clause to the query.
func (b SelectBuilder) JoinClause(pred interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder)
}
// Join adds a JOIN clause to the query.
func (b SelectBuilder) Join(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("JOIN "+join, rest...)
}
// LeftJoin adds a LEFT JOIN clause to the query.
func (b SelectBuilder) LeftJoin(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("LEFT JOIN "+join, rest...)
}
// RightJoin adds a RIGHT JOIN clause to the query.
func (b SelectBuilder) RightJoin(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("RIGHT JOIN "+join, rest...)
}
// InnerJoin adds a INNER JOIN clause to the query.
func (b SelectBuilder) InnerJoin(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("INNER JOIN "+join, rest...)
}
// CrossJoin adds a CROSS JOIN clause to the query.
func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("CROSS JOIN "+join, rest...)
}
// Where adds an expression to the WHERE clause of the query.
//
// Expressions are ANDed together in the generated SQL.
//
// Where accepts several types for its pred argument:
//
// nil OR "" - ignored.
//
// string - SQL expression.
// If the expression has SQL placeholders then a set of arguments must be passed
// as well, one for each placeholder.
//
// map[string]interface{} OR Eq - map of SQL expressions to values. Each key is
// transformed into an expression like "<key> = ?", with the corresponding value
// bound to the placeholder. If the value is nil, the expression will be "<key>
// IS NULL". If the value is an array or slice, the expression will be "<key> IN
// (?,?,...)", with one placeholder for each item in the value. These expressions
// are ANDed together.
//
// Where will panic if pred isn't any of the above types.
func (b SelectBuilder) Where(pred interface{}, args ...interface{}) SelectBuilder {
if pred == nil || pred == "" {
return b
}
return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(SelectBuilder)
}
// GroupBy adds GROUP BY expressions to the query.
func (b SelectBuilder) GroupBy(groupBys ...string) SelectBuilder {
return builder.Extend(b, "GroupBys", groupBys).(SelectBuilder)
}
// Having adds an expression to the HAVING clause of the query.
//
// See Where.
func (b SelectBuilder) Having(pred interface{}, rest ...interface{}) SelectBuilder {
return builder.Append(b, "HavingParts", newWherePart(pred, rest...)).(SelectBuilder)
}
// OrderByClause adds ORDER BY clause to the query.
func (b SelectBuilder) OrderByClause(pred interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "OrderByParts", newPart(pred, args...)).(SelectBuilder)
}
// OrderBy adds ORDER BY expressions to the query.
func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder {
for _, orderBy := range orderBys {
b = b.OrderByClause(orderBy)
}
return b
}
// Limit sets a LIMIT clause on the query.
func (b SelectBuilder) Limit(limit uint64) SelectBuilder {
return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(SelectBuilder)
}
// Limit ALL allows to access all records with limit
func (b SelectBuilder) RemoveLimit() SelectBuilder {
return builder.Delete(b, "Limit").(SelectBuilder)
}
// Offset sets a OFFSET clause on the query.
func (b SelectBuilder) Offset(offset uint64) SelectBuilder {
return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(SelectBuilder)
}
// RemoveOffset removes OFFSET clause.
func (b SelectBuilder) RemoveOffset() SelectBuilder {
return builder.Delete(b, "Offset").(SelectBuilder)
}
// Suffix adds an expression to the end of the query
func (b SelectBuilder) Suffix(sql string, args ...interface{}) SelectBuilder {
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b SelectBuilder) SuffixExpr(expr Sqlizer) SelectBuilder {
return builder.Append(b, "Suffixes", expr).(SelectBuilder)
}

View File

@ -1,69 +0,0 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *selectData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *selectData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *selectData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b SelectBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(selectData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b SelectBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(selectData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b SelectBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(selectData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b SelectBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@ -1,183 +0,0 @@
// Package squirrel provides a fluent SQL generator.
//
// See https://github.com/Masterminds/squirrel for examples.
package squirrel
import (
"bytes"
"database/sql"
"fmt"
"strings"
"github.com/lann/builder"
)
// Sqlizer is the interface that wraps the ToSql method.
//
// ToSql returns a SQL representation of the Sqlizer, along with a slice of args
// as passed to e.g. database/sql.Exec. It can also return an error.
type Sqlizer interface {
ToSql() (string, []interface{}, error)
}
// rawSqlizer is expected to do what Sqlizer does, but without finalizing placeholders.
// This is useful for nested queries.
type rawSqlizer interface {
toSqlRaw() (string, []interface{}, error)
}
// Execer is the interface that wraps the Exec method.
//
// Exec executes the given query as implemented by database/sql.Exec.
type Execer interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}
// Queryer is the interface that wraps the Query method.
//
// Query executes the given query as implemented by database/sql.Query.
type Queryer interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
}
// QueryRower is the interface that wraps the QueryRow method.
//
// QueryRow executes the given query as implemented by database/sql.QueryRow.
type QueryRower interface {
QueryRow(query string, args ...interface{}) RowScanner
}
// BaseRunner groups the Execer and Queryer interfaces.
type BaseRunner interface {
Execer
Queryer
}
// Runner groups the Execer, Queryer, and QueryRower interfaces.
type Runner interface {
Execer
Queryer
QueryRower
}
// WrapStdSql wraps a type implementing the standard SQL interface with methods that
// squirrel expects.
func WrapStdSql(stdSql StdSql) Runner {
return &stdsqlRunner{stdSql}
}
// StdSql encompasses the standard methods of the *sql.DB type, and other types that
// wrap these methods.
type StdSql interface {
Query(string, ...interface{}) (*sql.Rows, error)
QueryRow(string, ...interface{}) *sql.Row
Exec(string, ...interface{}) (sql.Result, error)
}
type stdsqlRunner struct {
StdSql
}
func (r *stdsqlRunner) QueryRow(query string, args ...interface{}) RowScanner {
return r.StdSql.QueryRow(query, args...)
}
func setRunWith(b interface{}, runner BaseRunner) interface{} {
switch r := runner.(type) {
case StdSqlCtx:
runner = WrapStdSqlCtx(r)
case StdSql:
runner = WrapStdSql(r)
}
return builder.Set(b, "RunWith", runner)
}
// RunnerNotSet is returned by methods that need a Runner if it isn't set.
var RunnerNotSet = fmt.Errorf("cannot run; no Runner set (RunWith)")
// RunnerNotQueryRunner is returned by QueryRow if the RunWith value doesn't implement QueryRower.
var RunnerNotQueryRunner = fmt.Errorf("cannot QueryRow; Runner is not a QueryRower")
// ExecWith Execs the SQL returned by s with db.
func ExecWith(db Execer, s Sqlizer) (res sql.Result, err error) {
query, args, err := s.ToSql()
if err != nil {
return
}
return db.Exec(query, args...)
}
// QueryWith Querys the SQL returned by s with db.
func QueryWith(db Queryer, s Sqlizer) (rows *sql.Rows, err error) {
query, args, err := s.ToSql()
if err != nil {
return
}
return db.Query(query, args...)
}
// QueryRowWith QueryRows the SQL returned by s with db.
func QueryRowWith(db QueryRower, s Sqlizer) RowScanner {
query, args, err := s.ToSql()
return &Row{RowScanner: db.QueryRow(query, args...), err: err}
}
// DebugSqlizer calls ToSql on s and shows the approximate SQL to be executed
//
// If ToSql returns an error, the result of this method will look like:
// "[ToSql error: %s]" or "[DebugSqlizer error: %s]"
//
// IMPORTANT: As its name suggests, this function should only be used for
// debugging. While the string result *might* be valid SQL, this function does
// not try very hard to ensure it. Additionally, executing the output of this
// function with any untrusted user input is certainly insecure.
func DebugSqlizer(s Sqlizer) string {
sql, args, err := s.ToSql()
if err != nil {
return fmt.Sprintf("[ToSql error: %s]", err)
}
var placeholder string
downCast, ok := s.(placeholderDebugger)
if !ok {
placeholder = "?"
} else {
placeholder = downCast.debugPlaceholder()
}
// TODO: dedupe this with placeholder.go
buf := &bytes.Buffer{}
i := 0
for {
p := strings.Index(sql, placeholder)
if p == -1 {
break
}
if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ?
buf.WriteString(sql[:p])
buf.WriteString("?")
if len(sql[p:]) == 1 {
break
}
sql = sql[p+2:]
} else {
if i+1 > len(args) {
return fmt.Sprintf(
"[DebugSqlizer error: too many placeholders in %#v for %d args]",
sql, len(args))
}
buf.WriteString(sql[:p])
fmt.Fprintf(buf, "'%v'", args[i])
// advance our sql string "cursor" beyond the arg we placed
sql = sql[p+1:]
i++
}
}
if i < len(args) {
return fmt.Sprintf(
"[DebugSqlizer error: not enough placeholders in %#v for %d args]",
sql, len(args))
}
// "append" any remaning sql that won't need interpolating
buf.WriteString(sql)
return buf.String()
}

View File

@ -1,93 +0,0 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"errors"
)
// NoContextSupport is returned if a db doesn't support Context.
var NoContextSupport = errors.New("DB does not support Context")
// ExecerContext is the interface that wraps the ExecContext method.
//
// Exec executes the given query as implemented by database/sql.ExecContext.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
// QueryerContext is the interface that wraps the QueryContext method.
//
// QueryContext executes the given query as implemented by database/sql.QueryContext.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// QueryRowerContext is the interface that wraps the QueryRowContext method.
//
// QueryRowContext executes the given query as implemented by database/sql.QueryRowContext.
type QueryRowerContext interface {
QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner
}
// RunnerContext groups the Runner interface, along with the Context versions of each of
// its methods
type RunnerContext interface {
Runner
QueryerContext
QueryRowerContext
ExecerContext
}
// WrapStdSqlCtx wraps a type implementing the standard SQL interface plus the context
// versions of the methods with methods that squirrel expects.
func WrapStdSqlCtx(stdSqlCtx StdSqlCtx) RunnerContext {
return &stdsqlCtxRunner{stdSqlCtx}
}
// StdSqlCtx encompasses the standard methods of the *sql.DB type, along with the Context
// versions of those methods, and other types that wrap these methods.
type StdSqlCtx interface {
StdSql
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
}
type stdsqlCtxRunner struct {
StdSqlCtx
}
func (r *stdsqlCtxRunner) QueryRow(query string, args ...interface{}) RowScanner {
return r.StdSqlCtx.QueryRow(query, args...)
}
func (r *stdsqlCtxRunner) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner {
return r.StdSqlCtx.QueryRowContext(ctx, query, args...)
}
// ExecContextWith ExecContexts the SQL returned by s with db.
func ExecContextWith(ctx context.Context, db ExecerContext, s Sqlizer) (res sql.Result, err error) {
query, args, err := s.ToSql()
if err != nil {
return
}
return db.ExecContext(ctx, query, args...)
}
// QueryContextWith QueryContexts the SQL returned by s with db.
func QueryContextWith(ctx context.Context, db QueryerContext, s Sqlizer) (rows *sql.Rows, err error) {
query, args, err := s.ToSql()
if err != nil {
return
}
return db.QueryContext(ctx, query, args...)
}
// QueryRowContextWith QueryRowContexts the SQL returned by s with db.
func QueryRowContextWith(ctx context.Context, db QueryRowerContext, s Sqlizer) RowScanner {
query, args, err := s.ToSql()
return &Row{RowScanner: db.QueryRowContext(ctx, query, args...), err: err}
}

View File

@ -1,104 +0,0 @@
package squirrel
import "github.com/lann/builder"
// StatementBuilderType is the type of StatementBuilder.
type StatementBuilderType builder.Builder
// Select returns a SelectBuilder for this StatementBuilderType.
func (b StatementBuilderType) Select(columns ...string) SelectBuilder {
return SelectBuilder(b).Columns(columns...)
}
// Insert returns a InsertBuilder for this StatementBuilderType.
func (b StatementBuilderType) Insert(into string) InsertBuilder {
return InsertBuilder(b).Into(into)
}
// Replace returns a InsertBuilder for this StatementBuilderType with the
// statement keyword set to "REPLACE".
func (b StatementBuilderType) Replace(into string) InsertBuilder {
return InsertBuilder(b).statementKeyword("REPLACE").Into(into)
}
// Update returns a UpdateBuilder for this StatementBuilderType.
func (b StatementBuilderType) Update(table string) UpdateBuilder {
return UpdateBuilder(b).Table(table)
}
// Delete returns a DeleteBuilder for this StatementBuilderType.
func (b StatementBuilderType) Delete(from string) DeleteBuilder {
return DeleteBuilder(b).From(from)
}
// PlaceholderFormat sets the PlaceholderFormat field for any child builders.
func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType {
return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType)
}
// RunWith sets the RunWith field for any child builders.
func (b StatementBuilderType) RunWith(runner BaseRunner) StatementBuilderType {
return setRunWith(b, runner).(StatementBuilderType)
}
// Where adds WHERE expressions to the query.
//
// See SelectBuilder.Where for more information.
func (b StatementBuilderType) Where(pred interface{}, args ...interface{}) StatementBuilderType {
return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(StatementBuilderType)
}
// StatementBuilder is a parent builder for other builders, e.g. SelectBuilder.
var StatementBuilder = StatementBuilderType(builder.EmptyBuilder).PlaceholderFormat(Question)
// Select returns a new SelectBuilder, optionally setting some result columns.
//
// See SelectBuilder.Columns.
func Select(columns ...string) SelectBuilder {
return StatementBuilder.Select(columns...)
}
// Insert returns a new InsertBuilder with the given table name.
//
// See InsertBuilder.Into.
func Insert(into string) InsertBuilder {
return StatementBuilder.Insert(into)
}
// Replace returns a new InsertBuilder with the statement keyword set to
// "REPLACE" and with the given table name.
//
// See InsertBuilder.Into.
func Replace(into string) InsertBuilder {
return StatementBuilder.Replace(into)
}
// Update returns a new UpdateBuilder with the given table name.
//
// See UpdateBuilder.Table.
func Update(table string) UpdateBuilder {
return StatementBuilder.Update(table)
}
// Delete returns a new DeleteBuilder with the given table name.
//
// See DeleteBuilder.Table.
func Delete(from string) DeleteBuilder {
return StatementBuilder.Delete(from)
}
// Case returns a new CaseBuilder
// "what" represents case value
func Case(what ...interface{}) CaseBuilder {
b := CaseBuilder(builder.EmptyBuilder)
switch len(what) {
case 0:
case 1:
b = b.what(what[0])
default:
b = b.what(newPart(what[0], what[1:]...))
}
return b
}

View File

@ -1,121 +0,0 @@
package squirrel
import (
"database/sql"
"fmt"
"sync"
)
// Prepareer is the interface that wraps the Prepare method.
//
// Prepare executes the given query as implemented by database/sql.Prepare.
type Preparer interface {
Prepare(query string) (*sql.Stmt, error)
}
// DBProxy groups the Execer, Queryer, QueryRower, and Preparer interfaces.
type DBProxy interface {
Execer
Queryer
QueryRower
Preparer
}
// NOTE: NewStmtCache is defined in stmtcacher_ctx.go (Go >= 1.8) or stmtcacher_noctx.go (Go < 1.8).
// StmtCache wraps and delegates down to a Preparer type
//
// It also automatically prepares all statements sent to the underlying Preparer calls
// for Exec, Query and QueryRow and caches the returns *sql.Stmt using the provided
// query as the key. So that it can be automatically re-used.
type StmtCache struct {
prep Preparer
cache map[string]*sql.Stmt
mu sync.Mutex
}
// Prepare delegates down to the underlying Preparer and caches the result
// using the provided query as a key
func (sc *StmtCache) Prepare(query string) (*sql.Stmt, error) {
sc.mu.Lock()
defer sc.mu.Unlock()
stmt, ok := sc.cache[query]
if ok {
return stmt, nil
}
stmt, err := sc.prep.Prepare(query)
if err == nil {
sc.cache[query] = stmt
}
return stmt, err
}
// Exec delegates down to the underlying Preparer using a prepared statement
func (sc *StmtCache) Exec(query string, args ...interface{}) (res sql.Result, err error) {
stmt, err := sc.Prepare(query)
if err != nil {
return
}
return stmt.Exec(args...)
}
// Query delegates down to the underlying Preparer using a prepared statement
func (sc *StmtCache) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := sc.Prepare(query)
if err != nil {
return
}
return stmt.Query(args...)
}
// QueryRow delegates down to the underlying Preparer using a prepared statement
func (sc *StmtCache) QueryRow(query string, args ...interface{}) RowScanner {
stmt, err := sc.Prepare(query)
if err != nil {
return &Row{err: err}
}
return stmt.QueryRow(args...)
}
// Clear removes and closes all the currently cached prepared statements
func (sc *StmtCache) Clear() (err error) {
sc.mu.Lock()
defer sc.mu.Unlock()
for key, stmt := range sc.cache {
delete(sc.cache, key)
if stmt == nil {
continue
}
if cerr := stmt.Close(); cerr != nil {
err = cerr
}
}
if err != nil {
return fmt.Errorf("one or more Stmt.Close failed; last error: %v", err)
}
return
}
type DBProxyBeginner interface {
DBProxy
Begin() (*sql.Tx, error)
}
type stmtCacheProxy struct {
DBProxy
db *sql.DB
}
func NewStmtCacheProxy(db *sql.DB) DBProxyBeginner {
return &stmtCacheProxy{DBProxy: NewStmtCache(db), db: db}
}
func (sp *stmtCacheProxy) Begin() (*sql.Tx, error) {
return sp.db.Begin()
}

View File

@ -1,86 +0,0 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
)
// PrepareerContext is the interface that wraps the Prepare and PrepareContext methods.
//
// Prepare executes the given query as implemented by database/sql.Prepare.
// PrepareContext executes the given query as implemented by database/sql.PrepareContext.
type PreparerContext interface {
Preparer
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// DBProxyContext groups the Execer, Queryer, QueryRower and PreparerContext interfaces.
type DBProxyContext interface {
Execer
Queryer
QueryRower
PreparerContext
}
// NewStmtCache returns a *StmtCache wrapping a PreparerContext that caches Prepared Stmts.
//
// Stmts are cached based on the string value of their queries.
func NewStmtCache(prep PreparerContext) *StmtCache {
return &StmtCache{prep: prep, cache: make(map[string]*sql.Stmt)}
}
// NewStmtCacher is deprecated
//
// Use NewStmtCache instead
func NewStmtCacher(prep PreparerContext) DBProxyContext {
return NewStmtCache(prep)
}
// PrepareContext delegates down to the underlying PreparerContext and caches the result
// using the provided query as a key
func (sc *StmtCache) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
ctxPrep, ok := sc.prep.(PreparerContext)
if !ok {
return nil, NoContextSupport
}
sc.mu.Lock()
defer sc.mu.Unlock()
stmt, ok := sc.cache[query]
if ok {
return stmt, nil
}
stmt, err := ctxPrep.PrepareContext(ctx, query)
if err == nil {
sc.cache[query] = stmt
}
return stmt, err
}
// ExecContext delegates down to the underlying PreparerContext using a prepared statement
func (sc *StmtCache) ExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) {
stmt, err := sc.PrepareContext(ctx, query)
if err != nil {
return
}
return stmt.ExecContext(ctx, args...)
}
// QueryContext delegates down to the underlying PreparerContext using a prepared statement
func (sc *StmtCache) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := sc.PrepareContext(ctx, query)
if err != nil {
return
}
return stmt.QueryContext(ctx, args...)
}
// QueryRowContext delegates down to the underlying PreparerContext using a prepared statement
func (sc *StmtCache) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner {
stmt, err := sc.PrepareContext(ctx, query)
if err != nil {
return &Row{err: err}
}
return stmt.QueryRowContext(ctx, args...)
}

View File

@ -1,21 +0,0 @@
// +build !go1.8
package squirrel
import (
"database/sql"
)
// NewStmtCacher returns a DBProxy wrapping prep that caches Prepared Stmts.
//
// Stmts are cached based on the string value of their queries.
func NewStmtCache(prep Preparer) *StmtCache {
return &StmtCacher{prep: prep, cache: make(map[string]*sql.Stmt)}
}
// NewStmtCacher is deprecated
//
// Use NewStmtCache instead
func NewStmtCacher(prep Preparer) DBProxy {
return NewStmtCache(prep)
}

View File

@ -1,288 +0,0 @@
package squirrel
import (
"bytes"
"database/sql"
"fmt"
"sort"
"strings"
"github.com/lann/builder"
)
type updateData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes []Sqlizer
Table string
SetClauses []setClause
From Sqlizer
WhereParts []Sqlizer
OrderBys []string
Limit string
Offset string
Suffixes []Sqlizer
}
type setClause struct {
column string
value interface{}
}
func (d *updateData) Exec() (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return ExecWith(d.RunWith, d)
}
func (d *updateData) Query() (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return QueryWith(d.RunWith, d)
}
func (d *updateData) QueryRow() RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRower)
if !ok {
return &Row{err: RunnerNotQueryRunner}
}
return QueryRowWith(queryRower, d)
}
func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.Table) == 0 {
err = fmt.Errorf("update statements must specify a table")
return
}
if len(d.SetClauses) == 0 {
err = fmt.Errorf("update statements must have at least one Set clause")
return
}
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
sql.WriteString("UPDATE ")
sql.WriteString(d.Table)
sql.WriteString(" SET ")
setSqls := make([]string, len(d.SetClauses))
for i, setClause := range d.SetClauses {
var valSql string
if vs, ok := setClause.value.(Sqlizer); ok {
vsql, vargs, err := vs.ToSql()
if err != nil {
return "", nil, err
}
if _, ok := vs.(SelectBuilder); ok {
valSql = fmt.Sprintf("(%s)", vsql)
} else {
valSql = vsql
}
args = append(args, vargs...)
} else {
valSql = "?"
args = append(args, setClause.value)
}
setSqls[i] = fmt.Sprintf("%s = %s", setClause.column, valSql)
}
sql.WriteString(strings.Join(setSqls, ", "))
if d.From != nil {
sql.WriteString(" FROM ")
args, err = appendToSql([]Sqlizer{d.From}, sql, "", args)
if err != nil {
return
}
}
if len(d.WhereParts) > 0 {
sql.WriteString(" WHERE ")
args, err = appendToSql(d.WhereParts, sql, " AND ", args)
if err != nil {
return
}
}
if len(d.OrderBys) > 0 {
sql.WriteString(" ORDER BY ")
sql.WriteString(strings.Join(d.OrderBys, ", "))
}
if len(d.Limit) > 0 {
sql.WriteString(" LIMIT ")
sql.WriteString(d.Limit)
}
if len(d.Offset) > 0 {
sql.WriteString(" OFFSET ")
sql.WriteString(d.Offset)
}
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
return
}
// Builder
// UpdateBuilder builds SQL UPDATE statements.
type UpdateBuilder builder.Builder
func init() {
builder.Register(UpdateBuilder{}, updateData{})
}
// Format methods
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder {
return builder.Set(b, "PlaceholderFormat", f).(UpdateBuilder)
}
// Runner methods
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
func (b UpdateBuilder) RunWith(runner BaseRunner) UpdateBuilder {
return setRunWith(b, runner).(UpdateBuilder)
}
// Exec builds and Execs the query with the Runner set by RunWith.
func (b UpdateBuilder) Exec() (sql.Result, error) {
data := builder.GetStruct(b).(updateData)
return data.Exec()
}
func (b UpdateBuilder) Query() (*sql.Rows, error) {
data := builder.GetStruct(b).(updateData)
return data.Query()
}
func (b UpdateBuilder) QueryRow() RowScanner {
data := builder.GetStruct(b).(updateData)
return data.QueryRow()
}
func (b UpdateBuilder) Scan(dest ...interface{}) error {
return b.QueryRow().Scan(dest...)
}
// SQL methods
// ToSql builds the query into a SQL string and bound args.
func (b UpdateBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(updateData)
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b UpdateBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b UpdateBuilder) Prefix(sql string, args ...interface{}) UpdateBuilder {
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b UpdateBuilder) PrefixExpr(expr Sqlizer) UpdateBuilder {
return builder.Append(b, "Prefixes", expr).(UpdateBuilder)
}
// Table sets the table to be updated.
func (b UpdateBuilder) Table(table string) UpdateBuilder {
return builder.Set(b, "Table", table).(UpdateBuilder)
}
// Set adds SET clauses to the query.
func (b UpdateBuilder) Set(column string, value interface{}) UpdateBuilder {
return builder.Append(b, "SetClauses", setClause{column: column, value: value}).(UpdateBuilder)
}
// SetMap is a convenience method which calls .Set for each key/value pair in clauses.
func (b UpdateBuilder) SetMap(clauses map[string]interface{}) UpdateBuilder {
keys := make([]string, len(clauses))
i := 0
for key := range clauses {
keys[i] = key
i++
}
sort.Strings(keys)
for _, key := range keys {
val, _ := clauses[key]
b = b.Set(key, val)
}
return b
}
// From adds FROM clause to the query
// FROM is valid construct in postgresql only.
func (b UpdateBuilder) From(from string) UpdateBuilder {
return builder.Set(b, "From", newPart(from)).(UpdateBuilder)
}
// FromSelect sets a subquery into the FROM clause of the query.
func (b UpdateBuilder) FromSelect(from SelectBuilder, alias string) UpdateBuilder {
// Prevent misnumbered parameters in nested selects (#183).
from = from.PlaceholderFormat(Question)
return builder.Set(b, "From", Alias(from, alias)).(UpdateBuilder)
}
// Where adds WHERE expressions to the query.
//
// See SelectBuilder.Where for more information.
func (b UpdateBuilder) Where(pred interface{}, args ...interface{}) UpdateBuilder {
return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(UpdateBuilder)
}
// OrderBy adds ORDER BY expressions to the query.
func (b UpdateBuilder) OrderBy(orderBys ...string) UpdateBuilder {
return builder.Extend(b, "OrderBys", orderBys).(UpdateBuilder)
}
// Limit sets a LIMIT clause on the query.
func (b UpdateBuilder) Limit(limit uint64) UpdateBuilder {
return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(UpdateBuilder)
}
// Offset sets a OFFSET clause on the query.
func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder {
return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(UpdateBuilder)
}
// Suffix adds an expression to the end of the query
func (b UpdateBuilder) Suffix(sql string, args ...interface{}) UpdateBuilder {
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b UpdateBuilder) SuffixExpr(expr Sqlizer) UpdateBuilder {
return builder.Append(b, "Suffixes", expr).(UpdateBuilder)
}

View File

@ -1,69 +0,0 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *updateData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *updateData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *updateData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b UpdateBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(updateData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b UpdateBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(updateData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b UpdateBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(updateData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b UpdateBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@ -1,30 +0,0 @@
package squirrel
import (
"fmt"
)
type wherePart part
func newWherePart(pred interface{}, args ...interface{}) Sqlizer {
return &wherePart{pred: pred, args: args}
}
func (p wherePart) ToSql() (sql string, args []interface{}, err error) {
switch pred := p.pred.(type) {
case nil:
// no-op
case rawSqlizer:
return pred.toSqlRaw()
case Sqlizer:
return pred.ToSql()
case map[string]interface{}:
return Eq(pred).ToSql()
case string:
sql = pred
args = p.args
default:
err = fmt.Errorf("expected string-keyed map or string, not %T", pred)
}
return
}

View File

@ -1,14 +0,0 @@
sudo: false
language: go
install: go get -t -v ./...
go:
- 1.2.x
- 1.3.x
- 1.4.x
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x

View File

@ -1,19 +0,0 @@
Copyright (C) 2014 Alec Thomas
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,709 +0,0 @@
# CONTRIBUTIONS ONLY
**What does this mean?** I do not have time to fix issues myself. The only way fixes or new features will be added is by people submitting PRs. If you are interested in taking over maintenance and have a history of contributions to Kingpin, please let me know.
**Current status.** Kingpin is largely feature stable. There hasn't been a need to add new features in a while, but there are some bugs that should be fixed.
**Why?** I no longer use Kingpin personally (I now use [kong](https://github.com/alecthomas/kong)). Rather than leave the project in a limbo of people filing issues and wondering why they're not being worked on, I believe this notice will more clearly set expectations.
# Kingpin - A Go (golang) command line and flag parser
[![](https://godoc.org/github.com/alecthomas/kingpin?status.svg)](http://godoc.org/github.com/alecthomas/kingpin) [![CI](https://github.com/alecthomas/kingpin/actions/workflows/ci.yml/badge.svg)](https://github.com/alecthomas/kingpin/actions/workflows/ci.yml)
<!-- MarkdownTOC -->
- [Overview](#overview)
- [Features](#features)
- [User-visible changes between v1 and v2](#user-visible-changes-between-v1-and-v2)
- [Flags can be used at any point after their definition.](#flags-can-be-used-at-any-point-after-their-definition)
- [Short flags can be combined with their parameters](#short-flags-can-be-combined-with-their-parameters)
- [API changes between v1 and v2](#api-changes-between-v1-and-v2)
- [Versions](#versions)
- [V2 is the current stable version](#v2-is-the-current-stable-version)
- [V1 is the OLD stable version](#v1-is-the-old-stable-version)
- [Change History](#change-history)
- [Examples](#examples)
- [Simple Example](#simple-example)
- [Complex Example](#complex-example)
- [Reference Documentation](#reference-documentation)
- [Displaying errors and usage information](#displaying-errors-and-usage-information)
- [Sub-commands](#sub-commands)
- [Custom Parsers](#custom-parsers)
- [Repeatable flags](#repeatable-flags)
- [Boolean Values](#boolean-values)
- [Default Values](#default-values)
- [Place-holders in Help](#place-holders-in-help)
- [Consuming all remaining arguments](#consuming-all-remaining-arguments)
- [Bash/ZSH Shell Completion](#bashzsh-shell-completion)
- [Supporting -h for help](#supporting--h-for-help)
- [Custom help](#custom-help)
<!-- /MarkdownTOC -->
## Overview
Kingpin is a [fluent-style](http://en.wikipedia.org/wiki/Fluent_interface),
type-safe command-line parser. It supports flags, nested commands, and
positional arguments.
Install it with:
$ go get github.com/alecthomas/kingpin/v2
It looks like this:
```go
var (
verbose = kingpin.Flag("verbose", "Verbose mode.").Short('v').Bool()
name = kingpin.Arg("name", "Name of user.").Required().String()
)
func main() {
kingpin.Parse()
fmt.Printf("%v, %s\n", *verbose, *name)
}
```
More [examples](https://github.com/alecthomas/kingpin/tree/master/_examples) are available.
Second to parsing, providing the user with useful help is probably the most
important thing a command-line parser does. Kingpin tries to provide detailed
contextual help if `--help` is encountered at any point in the command line
(excluding after `--`).
## Features
- Help output that isn't as ugly as sin.
- Fully [customisable help](#custom-help), via Go templates.
- Parsed, type-safe flags (`kingpin.Flag("f", "help").Int()`)
- Parsed, type-safe positional arguments (`kingpin.Arg("a", "help").Int()`).
- Parsed, type-safe, arbitrarily deep commands (`kingpin.Command("c", "help")`).
- Support for required flags and required positional arguments (`kingpin.Flag("f", "").Required().Int()`).
- Support for arbitrarily nested default commands (`command.Default()`).
- Callbacks per command, flag and argument (`kingpin.Command("c", "").Action(myAction)`).
- POSIX-style short flag combining (`-a -b` -> `-ab`).
- Short-flag+parameter combining (`-a parm` -> `-aparm`).
- Read command-line from files (`@<file>`).
- Automatically generate man pages (`--help-man`).
## User-visible changes between v1 and v2
### Flags can be used at any point after their definition.
Flags can be specified at any point after their definition, not just
*immediately after their associated command*. From the chat example below, the
following used to be required:
```
$ chat --server=chat.server.com:8080 post --image=~/Downloads/owls.jpg pics
```
But the following will now work:
```
$ chat post --server=chat.server.com:8080 --image=~/Downloads/owls.jpg pics
```
### Short flags can be combined with their parameters
Previously, if a short flag was used, any argument to that flag would have to
be separated by a space. That is no longer the case.
## API changes between v1 and v2
- `ParseWithFileExpansion()` is gone. The new parser directly supports expanding `@<file>`.
- Added `FatalUsage()` and `FatalUsageContext()` for displaying an error + usage and terminating.
- `Dispatch()` renamed to `Action()`.
- Added `ParseContext()` for parsing a command line into its intermediate context form without executing.
- Added `Terminate()` function to override the termination function.
- Added `UsageForContextWithTemplate()` for printing usage via a custom template.
- Added `UsageTemplate()` for overriding the default template to use. Two templates are included:
1. `DefaultUsageTemplate` - default template.
2. `CompactUsageTemplate` - compact command template for larger applications.
## Versions
The current stable version is [github.com/alecthomas/kingpin/v2](https://github.com/alecthomas/kingpin/v2). The previous version, [gopkg.in/alecthomas/kingpin.v1](https://gopkg.in/alecthomas/kingpin.v1), is deprecated and in maintenance mode.
### [V2](https://github.com/alecthomas/kingpin/v2) is the current stable version
Installation:
```sh
$ go get github.com/alecthomas/kingpin/v2
```
### [V1](https://gopkg.in/alecthomas/kingpin.v1) is the OLD stable version
Installation:
```sh
$ go get gopkg.in/alecthomas/kingpin.v1
```
## Change History
- *2015-09-19* -- Stable v2.1.0 release.
- Added `command.Default()` to specify a default command to use if no other
command matches. This allows for convenient user shortcuts.
- Exposed `HelpFlag` and `VersionFlag` for further customisation.
- `Action()` and `PreAction()` added and both now support an arbitrary
number of callbacks.
- `kingpin.SeparateOptionalFlagsUsageTemplate`.
- `--help-long` and `--help-man` (hidden by default) flags.
- Flags are "interspersed" by default, but can be disabled with `app.Interspersed(false)`.
- Added flags for all simple builtin types (int8, uint16, etc.) and slice variants.
- Use `app.Writer(os.Writer)` to specify the default writer for all output functions.
- Dropped `os.Writer` prefix from all printf-like functions.
- *2015-05-22* -- Stable v2.0.0 release.
- Initial stable release of v2.0.0.
- Fully supports interspersed flags, commands and arguments.
- Flags can be present at any point after their logical definition.
- Application.Parse() terminates if commands are present and a command is not parsed.
- Dispatch() -> Action().
- Actions are dispatched after all values are populated.
- Override termination function (defaults to os.Exit).
- Override output stream (defaults to os.Stderr).
- Templatised usage help, with default and compact templates.
- Make error/usage functions more consistent.
- Support argument expansion from files by default (with @<file>).
- Fully public data model is available via .Model().
- Parser has been completely refactored.
- Parsing and execution has been split into distinct stages.
- Use `go generate` to generate repeated flags.
- Support combined short-flag+argument: -fARG.
- *2015-01-23* -- Stable v1.3.4 release.
- Support "--" for separating flags from positional arguments.
- Support loading flags from files (ParseWithFileExpansion()). Use @FILE as an argument.
- Add post-app and post-cmd validation hooks. This allows arbitrary validation to be added.
- A bunch of improvements to help usage and formatting.
- Support arbitrarily nested sub-commands.
- *2014-07-08* -- Stable v1.2.0 release.
- Pass any value through to `Strings()` when final argument.
Allows for values that look like flags to be processed.
- Allow `--help` to be used with commands.
- Support `Hidden()` flags.
- Parser for [units.Base2Bytes](https://github.com/alecthomas/units)
type. Allows for flags like `--ram=512MB` or `--ram=1GB`.
- Add an `Enum()` value, allowing only one of a set of values
to be selected. eg. `Flag(...).Enum("debug", "info", "warning")`.
- *2014-06-27* -- Stable v1.1.0 release.
- Bug fixes.
- Always return an error (rather than panicing) when misconfigured.
- `OpenFile(flag, perm)` value type added, for finer control over opening files.
- Significantly improved usage formatting.
- *2014-06-19* -- Stable v1.0.0 release.
- Support [cumulative positional](#consuming-all-remaining-arguments) arguments.
- Return error rather than panic when there are fatal errors not caught by
the type system. eg. when a default value is invalid.
- Use gokpg.in.
- *2014-06-10* -- Place-holder streamlining.
- Renamed `MetaVar` to `PlaceHolder`.
- Removed `MetaVarFromDefault`. Kingpin now uses [heuristics](#place-holders-in-help)
to determine what to display.
## Examples
### Simple Example
Kingpin can be used for simple flag+arg applications like so:
```
$ ping --help
usage: ping [<flags>] <ip> [<count>]
Flags:
--debug Enable debug mode.
--help Show help.
-t, --timeout=5s Timeout waiting for ping.
Args:
<ip> IP address to ping.
[<count>] Number of packets to send
$ ping 1.2.3.4 5
Would ping: 1.2.3.4 with timeout 5s and count 5
```
From the following source:
```go
package main
import (
"fmt"
"github.com/alecthomas/kingpin/v2"
)
var (
debug = kingpin.Flag("debug", "Enable debug mode.").Bool()
timeout = kingpin.Flag("timeout", "Timeout waiting for ping.").Default("5s").Envar("PING_TIMEOUT").Short('t').Duration()
ip = kingpin.Arg("ip", "IP address to ping.").Required().IP()
count = kingpin.Arg("count", "Number of packets to send").Int()
)
func main() {
kingpin.Version("0.0.1")
kingpin.Parse()
fmt.Printf("Would ping: %s with timeout %s and count %d\n", *ip, *timeout, *count)
}
```
#### Reading arguments from a file
Kingpin supports reading arguments from a file.
Create a file with the corresponding arguments:
```
echo -t=5\n > args
```
And now supply it:
```
$ ping @args
```
### Complex Example
Kingpin can also produce complex command-line applications with global flags,
subcommands, and per-subcommand flags, like this:
```
$ chat --help
usage: chat [<flags>] <command> [<flags>] [<args> ...]
A command-line chat application.
Flags:
--help Show help.
--debug Enable debug mode.
--server=127.0.0.1 Server address.
Commands:
help [<command>]
Show help for a command.
register <nick> <name>
Register a new user.
post [<flags>] <channel> [<text>]
Post a message to a channel.
$ chat help post
usage: chat [<flags>] post [<flags>] <channel> [<text>]
Post a message to a channel.
Flags:
--image=IMAGE Image to post.
Args:
<channel> Channel to post to.
[<text>] Text to post.
$ chat post --image=~/Downloads/owls.jpg pics
...
```
From this code:
```go
package main
import (
"os"
"strings"
"github.com/alecthomas/kingpin/v2"
)
var (
app = kingpin.New("chat", "A command-line chat application.")
debug = app.Flag("debug", "Enable debug mode.").Bool()
serverIP = app.Flag("server", "Server address.").Default("127.0.0.1").IP()
register = app.Command("register", "Register a new user.")
registerNick = register.Arg("nick", "Nickname for user.").Required().String()
registerName = register.Arg("name", "Name of user.").Required().String()
post = app.Command("post", "Post a message to a channel.")
postImage = post.Flag("image", "Image to post.").File()
postChannel = post.Arg("channel", "Channel to post to.").Required().String()
postText = post.Arg("text", "Text to post.").Strings()
)
func main() {
switch kingpin.MustParse(app.Parse(os.Args[1:])) {
// Register user
case register.FullCommand():
println(*registerNick)
// Post message
case post.FullCommand():
if *postImage != nil {
}
text := strings.Join(*postText, " ")
println("Post:", text)
}
}
```
## Reference Documentation
### Displaying errors and usage information
Kingpin exports a set of functions to provide consistent errors and usage
information to the user.
Error messages look something like this:
<app>: error: <message>
The functions on `Application` are:
Function | Purpose
---------|--------------
`Errorf(format, args)` | Display a printf formatted error to the user.
`Fatalf(format, args)` | As with Errorf, but also call the termination handler.
`FatalUsage(format, args)` | As with Fatalf, but also print contextual usage information.
`FatalUsageContext(context, format, args)` | As with Fatalf, but also print contextual usage information from a `ParseContext`.
`FatalIfError(err, format, args)` | Conditionally print an error prefixed with format+args, then call the termination handler
There are equivalent global functions in the kingpin namespace for the default
`kingpin.CommandLine` instance.
### Sub-commands
Kingpin supports nested sub-commands, with separate flag and positional
arguments per sub-command. Note that positional arguments may only occur after
sub-commands.
For example:
```go
var (
deleteCommand = kingpin.Command("delete", "Delete an object.")
deleteUserCommand = deleteCommand.Command("user", "Delete a user.")
deleteUserUIDFlag = deleteUserCommand.Flag("uid", "Delete user by UID rather than username.")
deleteUserUsername = deleteUserCommand.Arg("username", "Username to delete.")
deletePostCommand = deleteCommand.Command("post", "Delete a post.")
)
func main() {
switch kingpin.Parse() {
case deleteUserCommand.FullCommand():
case deletePostCommand.FullCommand():
}
}
```
### Custom Parsers
Kingpin supports both flag and positional argument parsers for converting to
Go types. For example, some included parsers are `Int()`, `Float()`,
`Duration()` and `ExistingFile()` (see [parsers.go](./parsers.go) for a complete list of included parsers).
Parsers conform to Go's [`flag.Value`](http://godoc.org/flag#Value)
interface, so any existing implementations will work.
For example, a parser for accumulating HTTP header values might look like this:
```go
type HTTPHeaderValue http.Header
func (h *HTTPHeaderValue) Set(value string) error {
parts := strings.SplitN(value, ":", 2)
if len(parts) != 2 {
return fmt.Errorf("expected HEADER:VALUE got '%s'", value)
}
(*http.Header)(h).Add(parts[0], parts[1])
return nil
}
func (h *HTTPHeaderValue) String() string {
return ""
}
```
As a convenience, I would recommend something like this:
```go
func HTTPHeader(s Settings) (target *http.Header) {
target = &http.Header{}
s.SetValue((*HTTPHeaderValue)(target))
return
}
```
You would use it like so:
```go
headers = HTTPHeader(kingpin.Flag("header", "Add a HTTP header to the request.").Short('H'))
```
### Repeatable flags
Depending on the `Value` they hold, some flags may be repeated. The
`IsCumulative() bool` function on `Value` tells if it's safe to call `Set()`
multiple times or if an error should be raised if several values are passed.
The built-in `Value`s returning slices and maps, as well as `Counter` are
examples of `Value`s that make a flag repeatable.
### Boolean values
Boolean values are uniquely managed by Kingpin. Each boolean flag will have a negative complement:
`--<name>` and `--no-<name>`.
### Default Values
The default value is the zero value for a type. This can be overridden with
the `Default(value...)` function on flags and arguments. This function accepts
one or several strings, which are parsed by the value itself, so they *must*
be compliant with the format expected.
### Place-holders in Help
The place-holder value for a flag is the value used in the help to describe
the value of a non-boolean flag.
The value provided to PlaceHolder() is used if provided, then the value
provided by Default() if provided, then finally the capitalised flag name is
used.
Here are some examples of flags with various permutations:
--name=NAME // Flag(...).String()
--name="Harry" // Flag(...).Default("Harry").String()
--name=FULL-NAME // Flag(...).PlaceHolder("FULL-NAME").Default("Harry").String()
### Consuming all remaining arguments
A common command-line idiom is to use all remaining arguments for some
purpose. eg. The following command accepts an arbitrary number of
IP addresses as positional arguments:
./cmd ping 10.1.1.1 192.168.1.1
Such arguments are similar to [repeatable flags](#repeatable-flags), but for
arguments. Therefore they use the same `IsCumulative() bool` function on the
underlying `Value`, so the built-in `Value`s for which the `Set()` function
can be called several times will consume multiple arguments.
To implement the above example with a custom `Value`, we might do something
like this:
```go
type ipList []net.IP
func (i *ipList) Set(value string) error {
if ip := net.ParseIP(value); ip == nil {
return fmt.Errorf("'%s' is not an IP address", value)
} else {
*i = append(*i, ip)
return nil
}
}
func (i *ipList) String() string {
return ""
}
func (i *ipList) IsCumulative() bool {
return true
}
func IPList(s Settings) (target *[]net.IP) {
target = new([]net.IP)
s.SetValue((*ipList)(target))
return
}
```
And use it like so:
```go
ips := IPList(kingpin.Arg("ips", "IP addresses to ping."))
```
### Bash/ZSH Shell Completion
By default, all flags and commands/subcommands generate completions
internally.
Out of the box, CLI tools using kingpin should be able to take advantage
of completion hinting for flags and commands. By specifying
`--completion-bash` as the first argument, your CLI tool will show
possible subcommands. By ending your argv with `--`, hints for flags
will be shown.
To allow your end users to take advantage you must package a
`/etc/bash_completion.d` script with your distribution (or the equivalent
for your target platform/shell). An alternative is to instruct your end
user to source a script from their `bash_profile` (or equivalent).
Fortunately Kingpin makes it easy to generate or source a script for use
with end users shells. `./yourtool --completion-script-bash` and
`./yourtool --completion-script-zsh` will generate these scripts for you.
**Installation by Package**
For the best user experience, you should bundle your pre-created
completion script with your CLI tool and install it inside
`/etc/bash_completion.d` (or equivalent). A good suggestion is to add
this as an automated step to your build pipeline, in the implementation
is improved for bug fixed.
**Installation by `bash_profile`**
Alternatively, instruct your users to add an additional statement to
their `bash_profile` (or equivalent):
```
eval "$(your-cli-tool --completion-script-bash)"
```
Or for ZSH
```
eval "$(your-cli-tool --completion-script-zsh)"
```
#### Additional API
To provide more flexibility, a completion option API has been
exposed for flags to allow user defined completion options, to extend
completions further than just EnumVar/Enum.
**Provide Static Options**
When using an `Enum` or `EnumVar`, users are limited to only the options
given. Maybe we wish to hint possible options to the user, but also
allow them to provide their own custom option. `HintOptions` gives
this functionality to flags.
```
app := kingpin.New("completion", "My application with bash completion.")
app.Flag("port", "Provide a port to connect to").
Required().
HintOptions("80", "443", "8080").
IntVar(&c.port)
```
**Provide Dynamic Options**
Consider the case that you needed to read a local database or a file to
provide suggestions. You can dynamically generate the options
```
func listHosts() []string {
// Provide a dynamic list of hosts from a hosts file or otherwise
// for bash completion. In this example we simply return static slice.
// You could use this functionality to reach into a hosts file to provide
// completion for a list of known hosts.
return []string{"sshhost.example", "webhost.example", "ftphost.example"}
}
app := kingpin.New("completion", "My application with bash completion.")
app.Flag("flag-1", "").HintAction(listHosts).String()
```
**EnumVar/Enum**
When using `Enum` or `EnumVar`, any provided options will be automatically
used for bash autocompletion. However, if you wish to provide a subset or
different options, you can use `HintOptions` or `HintAction` which will override
the default completion options for `Enum`/`EnumVar`.
**Examples**
You can see an in depth example of the completion API within
`examples/completion/main.go`
### Supporting -h for help
`kingpin.CommandLine.HelpFlag.Short('h')`
Short help is also available when creating a more complicated app:
```go
var (
app = kingpin.New("chat", "A command-line chat application.")
// ...
)
func main() {
app.HelpFlag.Short('h')
switch kingpin.MustParse(app.Parse(os.Args[1:])) {
// ...
}
}
```
### Custom help
Kingpin v2 supports templatised help using the text/template library (actually, [a fork](https://github.com/alecthomas/template)).
You can specify the template to use with the [Application.UsageTemplate()](http://godoc.org/github.com/alecthomas/kingpin/v2#Application.UsageTemplate) function.
There are four included templates: `kingpin.DefaultUsageTemplate` is the default,
`kingpin.CompactUsageTemplate` provides a more compact representation for more complex command-line structures,
`kingpin.SeparateOptionalFlagsUsageTemplate` looks like the default template, but splits required
and optional command flags into separate lists, and `kingpin.ManPageTemplate` is used to generate man pages.
See the above templates for examples of usage, and the the function [UsageForContextWithTemplate()](https://github.com/alecthomas/kingpin/blob/master/usage.go#L198) method for details on the context.
#### Default help template
```
$ go run ./examples/curl/curl.go --help
usage: curl [<flags>] <command> [<args> ...]
An example implementation of curl.
Flags:
--help Show help.
-t, --timeout=5s Set connection timeout.
-H, --headers=HEADER=VALUE
Add HTTP headers to the request.
Commands:
help [<command>...]
Show help.
get url <url>
Retrieve a URL.
get file <file>
Retrieve a file.
post [<flags>] <url>
POST a resource.
```
#### Compact help template
```
$ go run ./examples/curl/curl.go --help
usage: curl [<flags>] <command> [<args> ...]
An example implementation of curl.
Flags:
--help Show help.
-t, --timeout=5s Set connection timeout.
-H, --headers=HEADER=VALUE
Add HTTP headers to the request.
Commands:
help [<command>...]
get [<flags>]
url <url>
file <file>
post [<flags>] <url>
```

View File

@ -1,42 +0,0 @@
package kingpin
// Action callback executed at various stages after all values are populated.
// The application, commands, arguments and flags all have corresponding
// actions.
type Action func(*ParseContext) error
type actionMixin struct {
actions []Action
preActions []Action
}
type actionApplier interface {
applyActions(*ParseContext) error
applyPreActions(*ParseContext) error
}
func (a *actionMixin) addAction(action Action) {
a.actions = append(a.actions, action)
}
func (a *actionMixin) addPreAction(action Action) {
a.preActions = append(a.preActions, action)
}
func (a *actionMixin) applyActions(context *ParseContext) error {
for _, action := range a.actions {
if err := action(context); err != nil {
return err
}
}
return nil
}
func (a *actionMixin) applyPreActions(context *ParseContext) error {
for _, preAction := range a.preActions {
if err := preAction(context); err != nil {
return err
}
}
return nil
}

View File

@ -1,703 +0,0 @@
package kingpin
import (
"fmt"
"io"
"os"
"regexp"
"strings"
"text/template"
)
var (
ErrCommandNotSpecified = fmt.Errorf("command not specified")
)
var (
envarTransformRegexp = regexp.MustCompile(`[^a-zA-Z0-9_]+`)
)
type ApplicationValidator func(*Application) error
// An Application contains the definitions of flags, arguments and commands
// for an application.
type Application struct {
cmdMixin
initialized bool
Name string
Help string
author string
version string
errorWriter io.Writer // Destination for errors.
usageWriter io.Writer // Destination for usage
usageTemplate string
usageFuncs template.FuncMap
validator ApplicationValidator
terminate func(status int) // See Terminate()
noInterspersed bool // can flags be interspersed with args (or must they come first)
defaultEnvars bool
completion bool
// Help flag. Exposed for user customisation.
HelpFlag *FlagClause
// Help command. Exposed for user customisation. May be nil.
HelpCommand *CmdClause
// Version flag. Exposed for user customisation. May be nil.
VersionFlag *FlagClause
}
// New creates a new Kingpin application instance.
func New(name, help string) *Application {
a := &Application{
Name: name,
Help: help,
errorWriter: os.Stderr, // Left for backwards compatibility purposes.
usageWriter: os.Stderr,
usageTemplate: DefaultUsageTemplate,
terminate: os.Exit,
}
a.flagGroup = newFlagGroup()
a.argGroup = newArgGroup()
a.cmdGroup = newCmdGroup(a)
a.HelpFlag = a.Flag("help", "Show context-sensitive help (also try --help-long and --help-man).")
a.HelpFlag.Bool()
a.Flag("help-long", "Generate long help.").Hidden().PreAction(a.generateLongHelp).Bool()
a.Flag("help-man", "Generate a man page.").Hidden().PreAction(a.generateManPage).Bool()
a.Flag("completion-bash", "Output possible completions for the given args.").Hidden().BoolVar(&a.completion)
a.Flag("completion-script-bash", "Generate completion script for bash.").Hidden().PreAction(a.generateBashCompletionScript).Bool()
a.Flag("completion-script-zsh", "Generate completion script for ZSH.").Hidden().PreAction(a.generateZSHCompletionScript).Bool()
return a
}
func (a *Application) generateLongHelp(c *ParseContext) error {
a.Writer(os.Stdout)
if err := a.UsageForContextWithTemplate(c, 2, LongHelpTemplate); err != nil {
return err
}
a.terminate(0)
return nil
}
func (a *Application) generateManPage(c *ParseContext) error {
a.Writer(os.Stdout)
if err := a.UsageForContextWithTemplate(c, 2, ManPageTemplate); err != nil {
return err
}
a.terminate(0)
return nil
}
func (a *Application) generateBashCompletionScript(c *ParseContext) error {
a.Writer(os.Stdout)
if err := a.UsageForContextWithTemplate(c, 2, BashCompletionTemplate); err != nil {
return err
}
a.terminate(0)
return nil
}
func (a *Application) generateZSHCompletionScript(c *ParseContext) error {
a.Writer(os.Stdout)
if err := a.UsageForContextWithTemplate(c, 2, ZshCompletionTemplate); err != nil {
return err
}
a.terminate(0)
return nil
}
// DefaultEnvars configures all flags (that do not already have an associated
// envar) to use a default environment variable in the form "<app>_<flag>".
//
// For example, if the application is named "foo" and a flag is named "bar-
// waz" the environment variable: "FOO_BAR_WAZ".
func (a *Application) DefaultEnvars() *Application {
a.defaultEnvars = true
return a
}
// Terminate specifies the termination handler. Defaults to os.Exit(status).
// If nil is passed, a no-op function will be used.
func (a *Application) Terminate(terminate func(int)) *Application {
if terminate == nil {
terminate = func(int) {}
}
a.terminate = terminate
return a
}
// Writer specifies the writer to use for usage and errors. Defaults to os.Stderr.
// DEPRECATED: See ErrorWriter and UsageWriter.
func (a *Application) Writer(w io.Writer) *Application {
a.errorWriter = w
a.usageWriter = w
return a
}
// ErrorWriter sets the io.Writer to use for errors.
func (a *Application) ErrorWriter(w io.Writer) *Application {
a.errorWriter = w
return a
}
// UsageWriter sets the io.Writer to use for errors.
func (a *Application) UsageWriter(w io.Writer) *Application {
a.usageWriter = w
return a
}
// UsageTemplate specifies the text template to use when displaying usage
// information. The default is UsageTemplate.
func (a *Application) UsageTemplate(template string) *Application {
a.usageTemplate = template
return a
}
// UsageFuncs adds extra functions that can be used in the usage template.
func (a *Application) UsageFuncs(funcs template.FuncMap) *Application {
a.usageFuncs = funcs
return a
}
// Validate sets a validation function to run when parsing.
func (a *Application) Validate(validator ApplicationValidator) *Application {
a.validator = validator
return a
}
// ParseContext parses the given command line and returns the fully populated
// ParseContext.
func (a *Application) ParseContext(args []string) (*ParseContext, error) {
return a.parseContext(false, args)
}
func (a *Application) parseContext(ignoreDefault bool, args []string) (*ParseContext, error) {
if err := a.init(); err != nil {
return nil, err
}
context := tokenize(args, ignoreDefault)
err := parse(context, a)
return context, err
}
// Parse parses command-line arguments. It returns the selected command and an
// error. The selected command will be a space separated subcommand, if
// subcommands have been configured.
//
// This will populate all flag and argument values, call all callbacks, and so
// on.
func (a *Application) Parse(args []string) (command string, err error) {
context, parseErr := a.ParseContext(args)
selected := []string{}
var setValuesErr error
if context == nil {
// Since we do not throw error immediately, there could be a case
// where a context returns nil. Protect against that.
return "", parseErr
}
if err = a.setDefaults(context); err != nil {
return "", err
}
selected, setValuesErr = a.setValues(context)
if err = a.applyPreActions(context, !a.completion); err != nil {
return "", err
}
if a.completion {
a.generateBashCompletion(context)
a.terminate(0)
} else {
if parseErr != nil {
return "", parseErr
}
a.maybeHelp(context)
if !context.EOL() {
return "", fmt.Errorf("unexpected argument '%s'", context.Peek())
}
if setValuesErr != nil {
return "", setValuesErr
}
command, err = a.execute(context, selected)
if err == ErrCommandNotSpecified {
a.writeUsage(context, nil)
}
}
return command, err
}
func (a *Application) writeUsage(context *ParseContext, err error) {
if err != nil {
a.Errorf("%s", err)
}
if err := a.UsageForContext(context); err != nil {
panic(err)
}
if err != nil {
a.terminate(1)
} else {
a.terminate(0)
}
}
func (a *Application) maybeHelp(context *ParseContext) {
for _, element := range context.Elements {
if flag, ok := element.Clause.(*FlagClause); ok && flag == a.HelpFlag {
// Re-parse the command-line ignoring defaults, so that help works correctly.
context, _ = a.parseContext(true, context.rawArgs)
a.writeUsage(context, nil)
}
}
}
// Version adds a --version flag for displaying the application version.
func (a *Application) Version(version string) *Application {
a.version = version
a.VersionFlag = a.Flag("version", "Show application version.").PreAction(func(*ParseContext) error {
fmt.Fprintln(a.usageWriter, version)
a.terminate(0)
return nil
})
a.VersionFlag.Bool()
return a
}
// Author sets the author output by some help templates.
func (a *Application) Author(author string) *Application {
a.author = author
return a
}
// Action callback to call when all values are populated and parsing is
// complete, but before any command, flag or argument actions.
//
// All Action() callbacks are called in the order they are encountered on the
// command line.
func (a *Application) Action(action Action) *Application {
a.addAction(action)
return a
}
// Action called after parsing completes but before validation and execution.
func (a *Application) PreAction(action Action) *Application {
a.addPreAction(action)
return a
}
// Command adds a new top-level command.
func (a *Application) Command(name, help string) *CmdClause {
return a.addCommand(name, help)
}
// Interspersed control if flags can be interspersed with positional arguments
//
// true (the default) means that they can, false means that all the flags must appear before the first positional arguments.
func (a *Application) Interspersed(interspersed bool) *Application {
a.noInterspersed = !interspersed
return a
}
func (a *Application) defaultEnvarPrefix() string {
if a.defaultEnvars {
return a.Name
}
return ""
}
func (a *Application) init() error {
if a.initialized {
return nil
}
if a.cmdGroup.have() && a.argGroup.have() {
return fmt.Errorf("can't mix top-level Arg()s with Command()s")
}
// If we have subcommands, add a help command at the top-level.
if a.cmdGroup.have() {
var command []string
a.HelpCommand = a.Command("help", "Show help.").PreAction(func(context *ParseContext) error {
a.Usage(command)
a.terminate(0)
return nil
})
a.HelpCommand.Arg("command", "Show help on command.").StringsVar(&command)
// Make help first command.
l := len(a.commandOrder)
a.commandOrder = append(a.commandOrder[l-1:l], a.commandOrder[:l-1]...)
}
if err := a.flagGroup.init(a.defaultEnvarPrefix()); err != nil {
return err
}
if err := a.cmdGroup.init(); err != nil {
return err
}
if err := a.argGroup.init(); err != nil {
return err
}
for _, cmd := range a.commands {
if err := cmd.init(); err != nil {
return err
}
}
flagGroups := []*flagGroup{a.flagGroup}
for _, cmd := range a.commandOrder {
if err := checkDuplicateFlags(cmd, flagGroups); err != nil {
return err
}
}
a.initialized = true
return nil
}
// Recursively check commands for duplicate flags.
func checkDuplicateFlags(current *CmdClause, flagGroups []*flagGroup) error {
// Check for duplicates.
for _, flags := range flagGroups {
for _, flag := range current.flagOrder {
if flag.shorthand != 0 {
if _, ok := flags.short[string(flag.shorthand)]; ok {
return fmt.Errorf("duplicate short flag -%c", flag.shorthand)
}
}
if _, ok := flags.long[flag.name]; ok {
return fmt.Errorf("duplicate long flag --%s", flag.name)
}
}
}
flagGroups = append(flagGroups, current.flagGroup)
// Check subcommands.
for _, subcmd := range current.commandOrder {
if err := checkDuplicateFlags(subcmd, flagGroups); err != nil {
return err
}
}
return nil
}
func (a *Application) execute(context *ParseContext, selected []string) (string, error) {
var err error
if err = a.validateRequired(context); err != nil {
return "", err
}
if err = a.applyValidators(context); err != nil {
return "", err
}
if err = a.applyActions(context); err != nil {
return "", err
}
command := strings.Join(selected, " ")
if command == "" && a.cmdGroup.have() {
return "", ErrCommandNotSpecified
}
return command, err
}
func (a *Application) setDefaults(context *ParseContext) error {
flagElements := map[string]*ParseElement{}
for _, element := range context.Elements {
if flag, ok := element.Clause.(*FlagClause); ok {
if flag.name == "help" {
return nil
}
if flag.name == "version" {
return nil
}
flagElements[flag.name] = element
}
}
argElements := map[string]*ParseElement{}
for _, element := range context.Elements {
if arg, ok := element.Clause.(*ArgClause); ok {
argElements[arg.name] = element
}
}
// Check required flags and set defaults.
for _, flag := range context.flags.long {
if flagElements[flag.name] == nil {
if err := flag.setDefault(); err != nil {
return err
}
}
}
for _, arg := range context.arguments.args {
if argElements[arg.name] == nil {
if err := arg.setDefault(); err != nil {
return err
}
}
}
return nil
}
func (a *Application) validateRequired(context *ParseContext) error {
flagElements := map[string]*ParseElement{}
for _, element := range context.Elements {
if flag, ok := element.Clause.(*FlagClause); ok {
flagElements[flag.name] = element
}
}
argElements := map[string]*ParseElement{}
for _, element := range context.Elements {
if arg, ok := element.Clause.(*ArgClause); ok {
argElements[arg.name] = element
}
}
// Check required flags and set defaults.
var missingFlags []string
for _, flag := range context.flags.long {
if flagElements[flag.name] == nil {
// Check required flags were provided.
if flag.needsValue() {
missingFlags = append(missingFlags, fmt.Sprintf("'--%s'", flag.name))
}
}
}
if len(missingFlags) != 0 {
return fmt.Errorf("required flag(s) %s not provided", strings.Join(missingFlags, ", "))
}
for _, arg := range context.arguments.args {
if argElements[arg.name] == nil {
if arg.needsValue() {
return fmt.Errorf("required argument '%s' not provided", arg.name)
}
}
}
return nil
}
func (a *Application) setValues(context *ParseContext) (selected []string, err error) {
// Set all arg and flag values.
var (
lastCmd *CmdClause
flagSet = map[string]struct{}{}
)
for _, element := range context.Elements {
switch clause := element.Clause.(type) {
case *FlagClause:
if _, ok := flagSet[clause.name]; ok {
if v, ok := clause.value.(repeatableFlag); !ok || !v.IsCumulative() {
return nil, fmt.Errorf("flag '%s' cannot be repeated", clause.name)
}
}
if err = clause.value.Set(*element.Value); err != nil {
return
}
flagSet[clause.name] = struct{}{}
case *ArgClause:
if err = clause.value.Set(*element.Value); err != nil {
return
}
case *CmdClause:
selected = append(selected, clause.name)
lastCmd = clause
}
}
if lastCmd != nil && len(lastCmd.commands) > 0 {
return nil, fmt.Errorf("must select a subcommand of '%s'", lastCmd.FullCommand())
}
return
}
func (a *Application) applyValidators(context *ParseContext) (err error) {
// Call command validation functions.
for _, element := range context.Elements {
if cmd, ok := element.Clause.(*CmdClause); ok && cmd.validator != nil {
if err = cmd.validator(cmd); err != nil {
return err
}
}
}
if a.validator != nil {
err = a.validator(a)
}
return err
}
func (a *Application) applyPreActions(context *ParseContext, dispatch bool) error {
if err := a.actionMixin.applyPreActions(context); err != nil {
return err
}
// Dispatch to actions.
if dispatch {
for _, element := range context.Elements {
if applier, ok := element.Clause.(actionApplier); ok {
if err := applier.applyPreActions(context); err != nil {
return err
}
}
}
}
return nil
}
func (a *Application) applyActions(context *ParseContext) error {
if err := a.actionMixin.applyActions(context); err != nil {
return err
}
// Dispatch to actions.
for _, element := range context.Elements {
if applier, ok := element.Clause.(actionApplier); ok {
if err := applier.applyActions(context); err != nil {
return err
}
}
}
return nil
}
// Errorf prints an error message to w in the format "<appname>: error: <message>".
func (a *Application) Errorf(format string, args ...interface{}) {
fmt.Fprintf(a.errorWriter, a.Name+": error: "+format+"\n", args...)
}
// Fatalf writes a formatted error to w then terminates with exit status 1.
func (a *Application) Fatalf(format string, args ...interface{}) {
a.Errorf(format, args...)
a.terminate(1)
}
// FatalUsage prints an error message followed by usage information, then
// exits with a non-zero status.
func (a *Application) FatalUsage(format string, args ...interface{}) {
a.Errorf(format, args...)
// Force usage to go to error output.
a.usageWriter = a.errorWriter
a.Usage([]string{})
a.terminate(1)
}
// FatalUsageContext writes a printf formatted error message to w, then usage
// information for the given ParseContext, before exiting.
func (a *Application) FatalUsageContext(context *ParseContext, format string, args ...interface{}) {
a.Errorf(format, args...)
if err := a.UsageForContext(context); err != nil {
panic(err)
}
a.terminate(1)
}
// FatalIfError prints an error and exits if err is not nil. The error is printed
// with the given formatted string, if any.
func (a *Application) FatalIfError(err error, format string, args ...interface{}) {
if err != nil {
prefix := ""
if format != "" {
prefix = fmt.Sprintf(format, args...) + ": "
}
a.Errorf(prefix+"%s", err)
a.terminate(1)
}
}
func (a *Application) completionOptions(context *ParseContext) []string {
args := context.rawArgs
var (
currArg string
prevArg string
target cmdMixin
)
numArgs := len(args)
if numArgs > 1 {
args = args[1:]
currArg = args[len(args)-1]
}
if numArgs > 2 {
prevArg = args[len(args)-2]
}
target = a.cmdMixin
if context.SelectedCommand != nil {
// A subcommand was in use. We will use it as the target
target = context.SelectedCommand.cmdMixin
}
if (currArg != "" && strings.HasPrefix(currArg, "--")) || strings.HasPrefix(prevArg, "--") {
if context.argsOnly {
return nil
}
// Perform completion for A flag. The last/current argument started with "-"
var (
flagName string // The name of a flag if given (could be half complete)
flagValue string // The value assigned to a flag (if given) (could be half complete)
)
if strings.HasPrefix(prevArg, "--") && !strings.HasPrefix(currArg, "--") {
// Matches: ./myApp --flag value
// Wont Match: ./myApp --flag --
flagName = prevArg[2:] // Strip the "--"
flagValue = currArg
} else if strings.HasPrefix(currArg, "--") {
// Matches: ./myApp --flag --
// Matches: ./myApp --flag somevalue --
// Matches: ./myApp --
flagName = currArg[2:] // Strip the "--"
}
options, flagMatched, valueMatched := target.FlagCompletion(flagName, flagValue)
if valueMatched {
// Value Matched. Show cmdCompletions
return target.CmdCompletion(context)
}
// Add top level flags if we're not at the top level and no match was found.
if context.SelectedCommand != nil && !flagMatched {
topOptions, topFlagMatched, topValueMatched := a.FlagCompletion(flagName, flagValue)
if topValueMatched {
// Value Matched. Back to cmdCompletions
return target.CmdCompletion(context)
}
if topFlagMatched {
// Top level had a flag which matched the input. Return it's options.
options = topOptions
} else {
// Add top level flags
options = append(options, topOptions...)
}
}
return options
}
// Perform completion for sub commands and arguments.
return target.CmdCompletion(context)
}
func (a *Application) generateBashCompletion(context *ParseContext) {
options := a.completionOptions(context)
fmt.Printf("%s", strings.Join(options, "\n"))
}
func envarTransform(name string) string {
return strings.ToUpper(envarTransformRegexp.ReplaceAllString(name, "_"))
}

View File

@ -1,205 +0,0 @@
package kingpin
import (
"fmt"
)
type argGroup struct {
args []*ArgClause
}
func newArgGroup() *argGroup {
return &argGroup{}
}
func (a *argGroup) have() bool {
return len(a.args) > 0
}
// GetArg gets an argument definition.
//
// This allows existing arguments to be modified after definition but before parsing. Useful for
// modular applications.
func (a *argGroup) GetArg(name string) *ArgClause {
for _, arg := range a.args {
if arg.name == name {
return arg
}
}
return nil
}
func (a *argGroup) Arg(name, help string) *ArgClause {
arg := newArg(name, help)
a.args = append(a.args, arg)
return arg
}
func (a *argGroup) init() error {
required := 0
seen := map[string]struct{}{}
previousArgMustBeLast := false
for i, arg := range a.args {
if previousArgMustBeLast {
return fmt.Errorf("Args() can't be followed by another argument '%s'", arg.name)
}
if arg.consumesRemainder() {
previousArgMustBeLast = true
}
if _, ok := seen[arg.name]; ok {
return fmt.Errorf("duplicate argument '%s'", arg.name)
}
seen[arg.name] = struct{}{}
if arg.required && required != i {
return fmt.Errorf("required arguments found after non-required")
}
if arg.required {
required++
}
if err := arg.init(); err != nil {
return err
}
}
return nil
}
type ArgClause struct {
actionMixin
parserMixin
completionsMixin
envarMixin
name string
help string
defaultValues []string
placeholder string
hidden bool
required bool
}
func newArg(name, help string) *ArgClause {
a := &ArgClause{
name: name,
help: help,
}
return a
}
func (a *ArgClause) setDefault() error {
if a.HasEnvarValue() {
if v, ok := a.value.(remainderArg); !ok || !v.IsCumulative() {
// Use the value as-is
return a.value.Set(a.GetEnvarValue())
}
for _, value := range a.GetSplitEnvarValue() {
if err := a.value.Set(value); err != nil {
return err
}
}
return nil
}
if len(a.defaultValues) > 0 {
for _, defaultValue := range a.defaultValues {
if err := a.value.Set(defaultValue); err != nil {
return err
}
}
return nil
}
return nil
}
func (a *ArgClause) needsValue() bool {
haveDefault := len(a.defaultValues) > 0
return a.required && !(haveDefault || a.HasEnvarValue())
}
func (a *ArgClause) consumesRemainder() bool {
if r, ok := a.value.(remainderArg); ok {
return r.IsCumulative()
}
return false
}
// Hidden hides the argument from usage but still allows it to be used.
func (a *ArgClause) Hidden() *ArgClause {
a.hidden = true
return a
}
// PlaceHolder sets the place-holder string used for arg values in the help. The
// default behaviour is to use the arg name between < > brackets.
func (a *ArgClause) PlaceHolder(value string) *ArgClause {
a.placeholder = value
return a
}
// Required arguments must be input by the user. They can not have a Default() value provided.
func (a *ArgClause) Required() *ArgClause {
a.required = true
return a
}
// Default values for this argument. They *must* be parseable by the value of the argument.
func (a *ArgClause) Default(values ...string) *ArgClause {
a.defaultValues = values
return a
}
// Envar overrides the default value(s) for a flag from an environment variable,
// if it is set. Several default values can be provided by using new lines to
// separate them.
func (a *ArgClause) Envar(name string) *ArgClause {
a.envar = name
a.noEnvar = false
return a
}
// NoEnvar forces environment variable defaults to be disabled for this flag.
// Most useful in conjunction with app.DefaultEnvars().
func (a *ArgClause) NoEnvar() *ArgClause {
a.envar = ""
a.noEnvar = true
return a
}
func (a *ArgClause) Action(action Action) *ArgClause {
a.addAction(action)
return a
}
func (a *ArgClause) PreAction(action Action) *ArgClause {
a.addPreAction(action)
return a
}
// HintAction registers a HintAction (function) for the arg to provide completions
func (a *ArgClause) HintAction(action HintAction) *ArgClause {
a.addHintAction(action)
return a
}
// HintOptions registers any number of options for the flag to provide completions
func (a *ArgClause) HintOptions(options ...string) *ArgClause {
a.addHintAction(func() []string {
return options
})
return a
}
// Help sets the help message.
func (a *ArgClause) Help(help string) *ArgClause {
a.help = help
return a
}
func (a *ArgClause) init() error {
if a.required && len(a.defaultValues) > 0 {
return fmt.Errorf("required argument '%s' with unusable default value", a.name)
}
if a.value == nil {
return fmt.Errorf("no parser defined for arg '%s'", a.name)
}
return nil
}

View File

@ -1,325 +0,0 @@
package kingpin
import (
"fmt"
"strings"
)
type cmdMixin struct {
*flagGroup
*argGroup
*cmdGroup
actionMixin
}
// CmdCompletion returns completion options for arguments, if that's where
// parsing left off, or commands if there aren't any unsatisfied args.
func (c *cmdMixin) CmdCompletion(context *ParseContext) []string {
var options []string
// Count args already satisfied - we won't complete those, and add any
// default commands' alternatives, since they weren't listed explicitly
// and the user may want to explicitly list something else.
argsSatisfied := 0
allSatisfied := false
ElementLoop:
for _, el := range context.Elements {
switch clause := el.Clause.(type) {
case *ArgClause:
// Each new element should reset the previous state
allSatisfied = false
options = nil
if el.Value != nil && *el.Value != "" {
// Get the list of valid options for the last argument
validOptions := c.argGroup.args[argsSatisfied].resolveCompletions()
if len(validOptions) == 0 {
// If there are no options for this argument,
// mark is as allSatisfied as we can't suggest anything
if !clause.consumesRemainder() {
argsSatisfied++
allSatisfied = true
}
continue ElementLoop
}
for _, opt := range validOptions {
if opt == *el.Value {
// We have an exact match
// We don't need to suggest any option
if !clause.consumesRemainder() {
argsSatisfied++
}
continue ElementLoop
}
if strings.HasPrefix(opt, *el.Value) {
// If the option match the partially entered argument, add it to the list
options = append(options, opt)
}
}
// Avoid further completion as we have done everything we could
if !clause.consumesRemainder() {
argsSatisfied++
allSatisfied = true
}
}
case *CmdClause:
options = append(options, clause.completionAlts...)
default:
}
}
if argsSatisfied < len(c.argGroup.args) && !allSatisfied {
// Since not all args have been satisfied, show options for the current one
options = append(options, c.argGroup.args[argsSatisfied].resolveCompletions()...)
} else {
// If all args are satisfied, then go back to completing commands
for _, cmd := range c.cmdGroup.commandOrder {
if !cmd.hidden {
options = append(options, cmd.name)
}
}
}
return options
}
func (c *cmdMixin) FlagCompletion(flagName string, flagValue string) (choices []string, flagMatch bool, optionMatch bool) {
// Check if flagName matches a known flag.
// If it does, show the options for the flag
// Otherwise, show all flags
options := []string{}
for _, flag := range c.flagGroup.flagOrder {
// Loop through each flag and determine if a match exists
if flag.name == flagName {
// User typed entire flag. Need to look for flag options.
options = flag.resolveCompletions()
if len(options) == 0 {
// No Options to Choose From, Assume Match.
return options, true, true
}
// Loop options to find if the user specified value matches
isPrefix := false
matched := false
for _, opt := range options {
if flagValue == opt {
matched = true
} else if strings.HasPrefix(opt, flagValue) {
isPrefix = true
}
}
// Matched Flag Directly
// Flag Value Not Prefixed, and Matched Directly
return options, true, !isPrefix && matched
}
if !flag.hidden {
options = append(options, "--"+flag.name)
}
}
// No Flag directly matched.
return options, false, false
}
type cmdGroup struct {
app *Application
parent *CmdClause
commands map[string]*CmdClause
commandOrder []*CmdClause
}
func (c *cmdGroup) defaultSubcommand() *CmdClause {
for _, cmd := range c.commandOrder {
if cmd.isDefault {
return cmd
}
}
return nil
}
func (c *cmdGroup) cmdNames() []string {
names := make([]string, 0, len(c.commandOrder))
for _, cmd := range c.commandOrder {
names = append(names, cmd.name)
}
return names
}
// GetArg gets a command definition.
//
// This allows existing commands to be modified after definition but before parsing. Useful for
// modular applications.
func (c *cmdGroup) GetCommand(name string) *CmdClause {
return c.commands[name]
}
func newCmdGroup(app *Application) *cmdGroup {
return &cmdGroup{
app: app,
commands: make(map[string]*CmdClause),
}
}
func (c *cmdGroup) flattenedCommands() (out []*CmdClause) {
for _, cmd := range c.commandOrder {
if len(cmd.commands) == 0 {
out = append(out, cmd)
}
out = append(out, cmd.flattenedCommands()...)
}
return
}
func (c *cmdGroup) addCommand(name, help string) *CmdClause {
cmd := newCommand(c.app, name, help)
c.commands[name] = cmd
c.commandOrder = append(c.commandOrder, cmd)
return cmd
}
func (c *cmdGroup) init() error {
seen := map[string]bool{}
if c.defaultSubcommand() != nil && !c.have() {
return fmt.Errorf("default subcommand %q provided but no subcommands defined", c.defaultSubcommand().name)
}
defaults := []string{}
for _, cmd := range c.commandOrder {
if cmd.isDefault {
defaults = append(defaults, cmd.name)
}
if seen[cmd.name] {
return fmt.Errorf("duplicate command %q", cmd.name)
}
seen[cmd.name] = true
for _, alias := range cmd.aliases {
if seen[alias] {
return fmt.Errorf("alias duplicates existing command %q", alias)
}
c.commands[alias] = cmd
}
if err := cmd.init(); err != nil {
return err
}
}
if len(defaults) > 1 {
return fmt.Errorf("more than one default subcommand exists: %s", strings.Join(defaults, ", "))
}
return nil
}
func (c *cmdGroup) have() bool {
return len(c.commands) > 0
}
type CmdClauseValidator func(*CmdClause) error
// A CmdClause is a single top-level command. It encapsulates a set of flags
// and either subcommands or positional arguments.
type CmdClause struct {
cmdMixin
app *Application
name string
aliases []string
help string
helpLong string
isDefault bool
validator CmdClauseValidator
hidden bool
completionAlts []string
}
func newCommand(app *Application, name, help string) *CmdClause {
c := &CmdClause{
app: app,
name: name,
help: help,
}
c.flagGroup = newFlagGroup()
c.argGroup = newArgGroup()
c.cmdGroup = newCmdGroup(app)
return c
}
// Add an Alias for this command.
func (c *CmdClause) Alias(name string) *CmdClause {
c.aliases = append(c.aliases, name)
return c
}
// Validate sets a validation function to run when parsing.
func (c *CmdClause) Validate(validator CmdClauseValidator) *CmdClause {
c.validator = validator
return c
}
func (c *CmdClause) FullCommand() string {
out := []string{c.name}
for p := c.parent; p != nil; p = p.parent {
out = append([]string{p.name}, out...)
}
return strings.Join(out, " ")
}
// Command adds a new sub-command.
func (c *CmdClause) Command(name, help string) *CmdClause {
cmd := c.addCommand(name, help)
cmd.parent = c
return cmd
}
// Default makes this command the default if commands don't match.
func (c *CmdClause) Default() *CmdClause {
c.isDefault = true
return c
}
func (c *CmdClause) Action(action Action) *CmdClause {
c.addAction(action)
return c
}
func (c *CmdClause) PreAction(action Action) *CmdClause {
c.addPreAction(action)
return c
}
// Help sets the help message.
func (c *CmdClause) Help(help string) *CmdClause {
c.help = help
return c
}
func (c *CmdClause) init() error {
if err := c.flagGroup.init(c.app.defaultEnvarPrefix()); err != nil {
return err
}
if c.argGroup.have() && c.cmdGroup.have() {
return fmt.Errorf("can't mix Arg()s with Command()s")
}
if err := c.argGroup.init(); err != nil {
return err
}
if err := c.cmdGroup.init(); err != nil {
return err
}
return nil
}
func (c *CmdClause) Hidden() *CmdClause {
c.hidden = true
return c
}
// HelpLong adds a long help text, which can be used in usage templates.
// For example, to use a longer help text in the command-specific help
// than in the apps root help.
func (c *CmdClause) HelpLong(help string) *CmdClause {
c.helpLong = help
return c
}

View File

@ -1,33 +0,0 @@
package kingpin
// HintAction is a function type who is expected to return a slice of possible
// command line arguments.
type HintAction func() []string
type completionsMixin struct {
hintActions []HintAction
builtinHintActions []HintAction
}
func (a *completionsMixin) addHintAction(action HintAction) {
a.hintActions = append(a.hintActions, action)
}
// Allow adding of HintActions which are added internally, ie, EnumVar
func (a *completionsMixin) addHintActionBuiltin(action HintAction) {
a.builtinHintActions = append(a.builtinHintActions, action)
}
func (a *completionsMixin) resolveCompletions() []string {
var hints []string
options := a.builtinHintActions
if len(a.hintActions) > 0 {
// User specified their own hintActions. Use those instead.
options = a.hintActions
}
for _, hintAction := range options {
hints = append(hints, hintAction()...)
}
return hints
}

View File

@ -1,68 +0,0 @@
// Package kingpin provides command line interfaces like this:
//
// $ chat
// usage: chat [<flags>] <command> [<flags>] [<args> ...]
//
// Flags:
// --debug enable debug mode
// --help Show help.
// --server=127.0.0.1 server address
//
// Commands:
// help <command>
// Show help for a command.
//
// post [<flags>] <channel>
// Post a message to a channel.
//
// register <nick> <name>
// Register a new user.
//
// $ chat help post
// usage: chat [<flags>] post [<flags>] <channel> [<text>]
//
// Post a message to a channel.
//
// Flags:
// --image=IMAGE image to post
//
// Args:
// <channel> channel to post to
// [<text>] text to post
// $ chat post --image=~/Downloads/owls.jpg pics
//
// From code like this:
//
// package main
//
// import "github.com/alecthomas/kingpin/v2"
//
// var (
// debug = kingpin.Flag("debug", "enable debug mode").Default("false").Bool()
// serverIP = kingpin.Flag("server", "server address").Default("127.0.0.1").IP()
//
// register = kingpin.Command("register", "Register a new user.")
// registerNick = register.Arg("nick", "nickname for user").Required().String()
// registerName = register.Arg("name", "name of user").Required().String()
//
// post = kingpin.Command("post", "Post a message to a channel.")
// postImage = post.Flag("image", "image to post").ExistingFile()
// postChannel = post.Arg("channel", "channel to post to").Required().String()
// postText = post.Arg("text", "text to post").String()
// )
//
// func main() {
// switch kingpin.Parse() {
// // Register user
// case "register":
// println(*registerNick)
//
// // Post message
// case "post":
// if *postImage != nil {
// }
// if *postText != "" {
// }
// }
// }
package kingpin

View File

@ -1,40 +0,0 @@
package kingpin
import (
"os"
"regexp"
)
var (
envVarValuesSeparator = "\r?\n"
envVarValuesTrimmer = regexp.MustCompile(envVarValuesSeparator + "$")
envVarValuesSplitter = regexp.MustCompile(envVarValuesSeparator)
)
type envarMixin struct {
envar string
noEnvar bool
}
func (e *envarMixin) HasEnvarValue() bool {
return e.GetEnvarValue() != ""
}
func (e *envarMixin) GetEnvarValue() string {
if e.noEnvar || e.envar == "" {
return ""
}
return os.Getenv(e.envar)
}
func (e *envarMixin) GetSplitEnvarValue() []string {
envarValue := e.GetEnvarValue()
if envarValue == "" {
return []string{}
}
// Split by new line to extract multiple values, if any.
trimmed := envVarValuesTrimmer.ReplaceAllString(envarValue, "")
return envVarValuesSplitter.Split(trimmed, -1)
}

View File

@ -1,332 +0,0 @@
package kingpin
import (
"fmt"
"strings"
)
type flagGroup struct {
short map[string]*FlagClause
long map[string]*FlagClause
flagOrder []*FlagClause
}
func newFlagGroup() *flagGroup {
return &flagGroup{
short: map[string]*FlagClause{},
long: map[string]*FlagClause{},
}
}
// GetFlag gets a flag definition.
//
// This allows existing flags to be modified after definition but before parsing. Useful for
// modular applications.
func (f *flagGroup) GetFlag(name string) *FlagClause {
return f.long[name]
}
// Flag defines a new flag with the given long name and help.
func (f *flagGroup) Flag(name, help string) *FlagClause {
flag := newFlag(name, help)
f.long[name] = flag
f.flagOrder = append(f.flagOrder, flag)
return flag
}
func (f *flagGroup) init(defaultEnvarPrefix string) error {
if err := f.checkDuplicates(); err != nil {
return err
}
for _, flag := range f.long {
if defaultEnvarPrefix != "" && !flag.noEnvar && flag.envar == "" {
flag.envar = envarTransform(defaultEnvarPrefix + "_" + flag.name)
}
if err := flag.init(); err != nil {
return err
}
if flag.shorthand != 0 {
f.short[string(flag.shorthand)] = flag
}
}
return nil
}
func (f *flagGroup) checkDuplicates() error {
seenShort := map[rune]bool{}
seenLong := map[string]bool{}
for _, flag := range f.flagOrder {
if flag.shorthand != 0 {
if _, ok := seenShort[flag.shorthand]; ok {
return fmt.Errorf("duplicate short flag -%c", flag.shorthand)
}
seenShort[flag.shorthand] = true
}
if _, ok := seenLong[flag.name]; ok {
return fmt.Errorf("duplicate long flag --%s", flag.name)
}
seenLong[flag.name] = true
}
return nil
}
func (f *flagGroup) parse(context *ParseContext) (*FlagClause, error) {
var token *Token
loop:
for {
token = context.Peek()
switch token.Type {
case TokenEOL:
break loop
case TokenLong, TokenShort:
flagToken := token
defaultValue := ""
var flag *FlagClause
var ok bool
invert := false
name := token.Value
if token.Type == TokenLong {
flag, ok = f.long[name]
if !ok {
if strings.HasPrefix(name, "no-") {
name = name[3:]
invert = true
}
flag, ok = f.long[name]
}
if !ok {
return nil, fmt.Errorf("unknown long flag '%s'", flagToken)
}
} else {
flag, ok = f.short[name]
if !ok {
return nil, fmt.Errorf("unknown short flag '%s'", flagToken)
}
}
context.Next()
flag.isSetByUser()
fb, ok := flag.value.(boolFlag)
if ok && fb.IsBoolFlag() {
if invert {
defaultValue = "false"
} else {
defaultValue = "true"
}
} else {
if invert {
context.Push(token)
return nil, fmt.Errorf("unknown long flag '%s'", flagToken)
}
token = context.Peek()
if token.Type != TokenArg {
context.Push(token)
return nil, fmt.Errorf("expected argument for flag '%s'", flagToken)
}
context.Next()
defaultValue = token.Value
}
context.matchedFlag(flag, defaultValue)
return flag, nil
default:
break loop
}
}
return nil, nil
}
// FlagClause is a fluid interface used to build flags.
type FlagClause struct {
parserMixin
actionMixin
completionsMixin
envarMixin
name string
shorthand rune
help string
defaultValues []string
placeholder string
hidden bool
setByUser *bool
}
func newFlag(name, help string) *FlagClause {
f := &FlagClause{
name: name,
help: help,
}
return f
}
func (f *FlagClause) setDefault() error {
if f.HasEnvarValue() {
if v, ok := f.value.(repeatableFlag); !ok || !v.IsCumulative() {
// Use the value as-is
return f.value.Set(f.GetEnvarValue())
} else {
for _, value := range f.GetSplitEnvarValue() {
if err := f.value.Set(value); err != nil {
return err
}
}
return nil
}
}
if len(f.defaultValues) > 0 {
for _, defaultValue := range f.defaultValues {
if err := f.value.Set(defaultValue); err != nil {
return err
}
}
return nil
}
return nil
}
func (f *FlagClause) isSetByUser() {
if f.setByUser != nil {
*f.setByUser = true
}
}
func (f *FlagClause) needsValue() bool {
haveDefault := len(f.defaultValues) > 0
return f.required && !(haveDefault || f.HasEnvarValue())
}
func (f *FlagClause) init() error {
if f.required && len(f.defaultValues) > 0 {
return fmt.Errorf("required flag '--%s' with default value that will never be used", f.name)
}
if f.value == nil {
return fmt.Errorf("no type defined for --%s (eg. .String())", f.name)
}
if v, ok := f.value.(repeatableFlag); (!ok || !v.IsCumulative()) && len(f.defaultValues) > 1 {
return fmt.Errorf("invalid default for '--%s', expecting single value", f.name)
}
return nil
}
// Dispatch to the given function after the flag is parsed and validated.
func (f *FlagClause) Action(action Action) *FlagClause {
f.addAction(action)
return f
}
func (f *FlagClause) PreAction(action Action) *FlagClause {
f.addPreAction(action)
return f
}
// HintAction registers a HintAction (function) for the flag to provide completions
func (a *FlagClause) HintAction(action HintAction) *FlagClause {
a.addHintAction(action)
return a
}
// HintOptions registers any number of options for the flag to provide completions
func (a *FlagClause) HintOptions(options ...string) *FlagClause {
a.addHintAction(func() []string {
return options
})
return a
}
func (a *FlagClause) EnumVar(target *string, options ...string) {
a.parserMixin.EnumVar(target, options...)
a.addHintActionBuiltin(func() []string {
return options
})
}
func (a *FlagClause) Enum(options ...string) (target *string) {
a.addHintActionBuiltin(func() []string {
return options
})
return a.parserMixin.Enum(options...)
}
// IsSetByUser let to know if the flag was set by the user
func (f *FlagClause) IsSetByUser(setByUser *bool) *FlagClause {
if setByUser != nil {
*setByUser = false
}
f.setByUser = setByUser
return f
}
// Default values for this flag. They *must* be parseable by the value of the flag.
func (f *FlagClause) Default(values ...string) *FlagClause {
f.defaultValues = values
return f
}
// DEPRECATED: Use Envar(name) instead.
func (f *FlagClause) OverrideDefaultFromEnvar(envar string) *FlagClause {
return f.Envar(envar)
}
// Envar overrides the default value(s) for a flag from an environment variable,
// if it is set. Several default values can be provided by using new lines to
// separate them.
func (f *FlagClause) Envar(name string) *FlagClause {
f.envar = name
f.noEnvar = false
return f
}
// NoEnvar forces environment variable defaults to be disabled for this flag.
// Most useful in conjunction with app.DefaultEnvars().
func (f *FlagClause) NoEnvar() *FlagClause {
f.envar = ""
f.noEnvar = true
return f
}
// PlaceHolder sets the place-holder string used for flag values in the help. The
// default behaviour is to use the value provided by Default() if provided,
// then fall back on the capitalized flag name.
func (f *FlagClause) PlaceHolder(placeholder string) *FlagClause {
f.placeholder = placeholder
return f
}
// Hidden hides a flag from usage but still allows it to be used.
func (f *FlagClause) Hidden() *FlagClause {
f.hidden = true
return f
}
// Required makes the flag required. You can not provide a Default() value to a Required() flag.
func (f *FlagClause) Required() *FlagClause {
f.required = true
return f
}
// Short sets the short flag name.
func (f *FlagClause) Short(name rune) *FlagClause {
f.shorthand = name
return f
}
// Help sets the help message.
func (f *FlagClause) Help(help string) *FlagClause {
f.help = help
return f
}
// Bool makes this flag a boolean flag.
func (f *FlagClause) Bool() (target *bool) {
target = new(bool)
f.SetValue(newBoolValue(target))
return
}

View File

@ -1,96 +0,0 @@
package kingpin
import (
"os"
"path/filepath"
)
var (
// CommandLine is the default Kingpin parser.
CommandLine = New(filepath.Base(os.Args[0]), "")
// Global help flag. Exposed for user customisation.
HelpFlag = CommandLine.HelpFlag
// Top-level help command. Exposed for user customisation. May be nil.
HelpCommand = CommandLine.HelpCommand
// Global version flag. Exposed for user customisation. May be nil.
VersionFlag = CommandLine.VersionFlag
// Whether to file expansion with '@' is enabled.
EnableFileExpansion = true
)
// Command adds a new command to the default parser.
func Command(name, help string) *CmdClause {
return CommandLine.Command(name, help)
}
// Flag adds a new flag to the default parser.
func Flag(name, help string) *FlagClause {
return CommandLine.Flag(name, help)
}
// Arg adds a new argument to the top-level of the default parser.
func Arg(name, help string) *ArgClause {
return CommandLine.Arg(name, help)
}
// Parse and return the selected command. Will call the termination handler if
// an error is encountered.
func Parse() string {
selected := MustParse(CommandLine.Parse(os.Args[1:]))
if selected == "" && CommandLine.cmdGroup.have() {
Usage()
CommandLine.terminate(0)
}
return selected
}
// Errorf prints an error message to stderr.
func Errorf(format string, args ...interface{}) {
CommandLine.Errorf(format, args...)
}
// Fatalf prints an error message to stderr and exits.
func Fatalf(format string, args ...interface{}) {
CommandLine.Fatalf(format, args...)
}
// FatalIfError prints an error and exits if err is not nil. The error is printed
// with the given prefix.
func FatalIfError(err error, format string, args ...interface{}) {
CommandLine.FatalIfError(err, format, args...)
}
// FatalUsage prints an error message followed by usage information, then
// exits with a non-zero status.
func FatalUsage(format string, args ...interface{}) {
CommandLine.FatalUsage(format, args...)
}
// FatalUsageContext writes a printf formatted error message to stderr, then
// usage information for the given ParseContext, before exiting.
func FatalUsageContext(context *ParseContext, format string, args ...interface{}) {
CommandLine.FatalUsageContext(context, format, args...)
}
// Usage prints usage to stderr.
func Usage() {
CommandLine.Usage(os.Args[1:])
}
// Set global usage template to use (defaults to DefaultUsageTemplate).
func UsageTemplate(template string) *Application {
return CommandLine.UsageTemplate(template)
}
// MustParse can be used with app.Parse(args) to exit with an error if parsing fails.
func MustParse(command string, err error) string {
if err != nil {
Fatalf("%s, try --help", err)
}
return command
}
// Version adds a flag for displaying the application version number.
func Version(version string) *Application {
return CommandLine.Version(version)
}

View File

@ -1,9 +0,0 @@
// +build appengine !linux,!freebsd,!darwin,!dragonfly,!netbsd,!openbsd
package kingpin
import "io"
func guessWidth(w io.Writer) int {
return 80
}

View File

@ -1,38 +0,0 @@
// +build !appengine,linux freebsd darwin dragonfly netbsd openbsd
package kingpin
import (
"io"
"os"
"strconv"
"syscall"
"unsafe"
)
func guessWidth(w io.Writer) int {
// check if COLUMNS env is set to comply with
// http://pubs.opengroup.org/onlinepubs/009604499/basedefs/xbd_chap08.html
colsStr := os.Getenv("COLUMNS")
if colsStr != "" {
if cols, err := strconv.Atoi(colsStr); err == nil {
return cols
}
}
if t, ok := w.(*os.File); ok {
fd := t.Fd()
var dimensions [4]uint16
if _, _, err := syscall.Syscall6(
syscall.SYS_IOCTL,
uintptr(fd),
uintptr(syscall.TIOCGWINSZ),
uintptr(unsafe.Pointer(&dimensions)),
0, 0, 0,
); err == 0 {
return int(dimensions[1])
}
}
return 80
}

View File

@ -1,273 +0,0 @@
package kingpin
import (
"fmt"
"strconv"
"strings"
)
// Data model for Kingpin command-line structure.
var (
ignoreInCount = map[string]bool{
"help": true,
"help-long": true,
"help-man": true,
"completion-bash": true,
"completion-script-bash": true,
"completion-script-zsh": true,
}
)
type FlagGroupModel struct {
Flags []*FlagModel
}
func (f *FlagGroupModel) FlagSummary() string {
out := []string{}
count := 0
for _, flag := range f.Flags {
if !ignoreInCount[flag.Name] {
count++
}
if flag.Required {
if flag.IsBoolFlag() {
out = append(out, fmt.Sprintf("--[no-]%s", flag.Name))
} else {
out = append(out, fmt.Sprintf("--%s=%s", flag.Name, flag.FormatPlaceHolder()))
}
}
}
if count != len(out) {
out = append(out, "[<flags>]")
}
return strings.Join(out, " ")
}
type FlagModel struct {
Name string
Help string
Short rune
Default []string
Envar string
PlaceHolder string
Required bool
Hidden bool
Value Value
}
func (f *FlagModel) String() string {
if f.Value == nil {
return ""
}
return f.Value.String()
}
func (f *FlagModel) IsBoolFlag() bool {
if fl, ok := f.Value.(boolFlag); ok {
return fl.IsBoolFlag()
}
return false
}
func (f *FlagModel) FormatPlaceHolder() string {
if f.PlaceHolder != "" {
return f.PlaceHolder
}
if len(f.Default) > 0 {
ellipsis := ""
if len(f.Default) > 1 {
ellipsis = "..."
}
if _, ok := f.Value.(*stringValue); ok {
return strconv.Quote(f.Default[0]) + ellipsis
}
return f.Default[0] + ellipsis
}
return strings.ToUpper(f.Name)
}
func (f *FlagModel) HelpWithEnvar() string {
if f.Envar == "" {
return f.Help
}
return fmt.Sprintf("%s ($%s)", f.Help, f.Envar)
}
type ArgGroupModel struct {
Args []*ArgModel
}
func (a *ArgGroupModel) ArgSummary() string {
depth := 0
out := []string{}
for _, arg := range a.Args {
var h string
if arg.PlaceHolder != "" {
h = arg.PlaceHolder
} else {
h = "<" + arg.Name + ">"
}
if !arg.Required {
h = "[" + h
depth++
}
out = append(out, h)
}
out[len(out)-1] = out[len(out)-1] + strings.Repeat("]", depth)
return strings.Join(out, " ")
}
func (a *ArgModel) HelpWithEnvar() string {
if a.Envar == "" {
return a.Help
}
return fmt.Sprintf("%s ($%s)", a.Help, a.Envar)
}
type ArgModel struct {
Name string
Help string
Default []string
Envar string
PlaceHolder string
Required bool
Hidden bool
Value Value
}
func (a *ArgModel) String() string {
if a.Value == nil {
return ""
}
return a.Value.String()
}
type CmdGroupModel struct {
Commands []*CmdModel
}
func (c *CmdGroupModel) FlattenedCommands() (out []*CmdModel) {
for _, cmd := range c.Commands {
if len(cmd.Commands) == 0 {
out = append(out, cmd)
}
out = append(out, cmd.FlattenedCommands()...)
}
return
}
type CmdModel struct {
Name string
Aliases []string
Help string
HelpLong string
FullCommand string
Depth int
Hidden bool
Default bool
*FlagGroupModel
*ArgGroupModel
*CmdGroupModel
}
func (c *CmdModel) String() string {
return c.FullCommand
}
type ApplicationModel struct {
Name string
Help string
Version string
Author string
*ArgGroupModel
*CmdGroupModel
*FlagGroupModel
}
func (a *Application) Model() *ApplicationModel {
return &ApplicationModel{
Name: a.Name,
Help: a.Help,
Version: a.version,
Author: a.author,
FlagGroupModel: a.flagGroup.Model(),
ArgGroupModel: a.argGroup.Model(),
CmdGroupModel: a.cmdGroup.Model(),
}
}
func (a *argGroup) Model() *ArgGroupModel {
m := &ArgGroupModel{}
for _, arg := range a.args {
m.Args = append(m.Args, arg.Model())
}
return m
}
func (a *ArgClause) Model() *ArgModel {
return &ArgModel{
Name: a.name,
Help: a.help,
Default: a.defaultValues,
Envar: a.envar,
PlaceHolder: a.placeholder,
Required: a.required,
Hidden: a.hidden,
Value: a.value,
}
}
func (f *flagGroup) Model() *FlagGroupModel {
m := &FlagGroupModel{}
for _, fl := range f.flagOrder {
m.Flags = append(m.Flags, fl.Model())
}
return m
}
func (f *FlagClause) Model() *FlagModel {
return &FlagModel{
Name: f.name,
Help: f.help,
Short: rune(f.shorthand),
Default: f.defaultValues,
Envar: f.envar,
PlaceHolder: f.placeholder,
Required: f.required,
Hidden: f.hidden,
Value: f.value,
}
}
func (c *cmdGroup) Model() *CmdGroupModel {
m := &CmdGroupModel{}
for _, cm := range c.commandOrder {
m.Commands = append(m.Commands, cm.Model())
}
return m
}
func (c *CmdClause) Model() *CmdModel {
depth := 0
for i := c; i != nil; i = i.parent {
depth++
}
return &CmdModel{
Name: c.name,
Aliases: c.aliases,
Help: c.help,
HelpLong: c.helpLong,
Depth: depth,
Hidden: c.hidden,
Default: c.isDefault,
FullCommand: c.FullCommand(),
FlagGroupModel: c.flagGroup.Model(),
ArgGroupModel: c.argGroup.Model(),
CmdGroupModel: c.cmdGroup.Model(),
}
}

View File

@ -1,396 +0,0 @@
package kingpin
import (
"bufio"
"fmt"
"os"
"strings"
"unicode/utf8"
)
type TokenType int
// Token types.
const (
TokenShort TokenType = iota
TokenLong
TokenArg
TokenError
TokenEOL
)
func (t TokenType) String() string {
switch t {
case TokenShort:
return "short flag"
case TokenLong:
return "long flag"
case TokenArg:
return "argument"
case TokenError:
return "error"
case TokenEOL:
return "<EOL>"
}
return "?"
}
var (
TokenEOLMarker = Token{-1, TokenEOL, ""}
)
type Token struct {
Index int
Type TokenType
Value string
}
func (t *Token) Equal(o *Token) bool {
return t.Index == o.Index
}
func (t *Token) IsFlag() bool {
return t.Type == TokenShort || t.Type == TokenLong
}
func (t *Token) IsEOF() bool {
return t.Type == TokenEOL
}
func (t *Token) String() string {
switch t.Type {
case TokenShort:
return "-" + t.Value
case TokenLong:
return "--" + t.Value
case TokenArg:
return t.Value
case TokenError:
return "error: " + t.Value
case TokenEOL:
return "<EOL>"
default:
panic("unhandled type")
}
}
// A union of possible elements in a parse stack.
type ParseElement struct {
// Clause is either *CmdClause, *ArgClause or *FlagClause.
Clause interface{}
// Value is corresponding value for an ArgClause or FlagClause (if any).
Value *string
}
// ParseContext holds the current context of the parser. When passed to
// Action() callbacks Elements will be fully populated with *FlagClause,
// *ArgClause and *CmdClause values and their corresponding arguments (if
// any).
type ParseContext struct {
SelectedCommand *CmdClause
ignoreDefault bool
argsOnly bool
peek []*Token
argi int // Index of current command-line arg we're processing.
args []string
rawArgs []string
flags *flagGroup
arguments *argGroup
argumenti int // Cursor into arguments
// Flags, arguments and commands encountered and collected during parse.
Elements []*ParseElement
}
func (p *ParseContext) nextArg() *ArgClause {
if p.argumenti >= len(p.arguments.args) {
return nil
}
arg := p.arguments.args[p.argumenti]
if !arg.consumesRemainder() {
p.argumenti++
}
return arg
}
func (p *ParseContext) next() {
p.argi++
p.args = p.args[1:]
}
// HasTrailingArgs returns true if there are unparsed command-line arguments.
// This can occur if the parser can not match remaining arguments.
func (p *ParseContext) HasTrailingArgs() bool {
return len(p.args) > 0
}
func tokenize(args []string, ignoreDefault bool) *ParseContext {
return &ParseContext{
ignoreDefault: ignoreDefault,
args: args,
rawArgs: args,
flags: newFlagGroup(),
arguments: newArgGroup(),
}
}
func (p *ParseContext) mergeFlags(flags *flagGroup) {
for _, flag := range flags.flagOrder {
if flag.shorthand != 0 {
p.flags.short[string(flag.shorthand)] = flag
}
p.flags.long[flag.name] = flag
p.flags.flagOrder = append(p.flags.flagOrder, flag)
}
}
func (p *ParseContext) mergeArgs(args *argGroup) {
p.arguments.args = append(p.arguments.args, args.args...)
}
func (p *ParseContext) EOL() bool {
return p.Peek().Type == TokenEOL
}
func (p *ParseContext) Error() bool {
return p.Peek().Type == TokenError
}
// Next token in the parse context.
func (p *ParseContext) Next() *Token {
if len(p.peek) > 0 {
return p.pop()
}
// End of tokens.
if len(p.args) == 0 {
return &Token{Index: p.argi, Type: TokenEOL}
}
if p.argi > 0 && p.argi <= len(p.rawArgs) && p.rawArgs[p.argi-1] == "--" {
// If the previous argument was a --, from now on only arguments are parsed.
p.argsOnly = true
}
arg := p.args[0]
p.next()
if p.argsOnly {
return &Token{p.argi, TokenArg, arg}
}
if arg == "--" {
return p.Next()
}
if strings.HasPrefix(arg, "--") {
parts := strings.SplitN(arg[2:], "=", 2)
token := &Token{p.argi, TokenLong, parts[0]}
if len(parts) == 2 {
p.Push(&Token{p.argi, TokenArg, parts[1]})
}
return token
}
if strings.HasPrefix(arg, "-") {
if len(arg) == 1 {
return &Token{Index: p.argi, Type: TokenArg}
}
shortRune, size := utf8.DecodeRuneInString(arg[1:])
short := string(shortRune)
flag, ok := p.flags.short[short]
// Not a known short flag, we'll just return it anyway.
if !ok {
} else if fb, ok := flag.value.(boolFlag); ok && fb.IsBoolFlag() {
// Bool short flag.
} else {
// Short flag with combined argument: -fARG
token := &Token{p.argi, TokenShort, short}
if len(arg) > size+1 {
p.Push(&Token{p.argi, TokenArg, arg[size+1:]})
}
return token
}
if len(arg) > size+1 {
p.args = append([]string{"-" + arg[size+1:]}, p.args...)
}
return &Token{p.argi, TokenShort, short}
} else if EnableFileExpansion && strings.HasPrefix(arg, "@") {
expanded, err := ExpandArgsFromFile(arg[1:])
if err != nil {
return &Token{p.argi, TokenError, err.Error()}
}
if len(p.args) == 0 {
p.args = expanded
} else {
p.args = append(expanded, p.args...)
}
return p.Next()
}
return &Token{p.argi, TokenArg, arg}
}
func (p *ParseContext) Peek() *Token {
if len(p.peek) == 0 {
return p.Push(p.Next())
}
return p.peek[len(p.peek)-1]
}
func (p *ParseContext) Push(token *Token) *Token {
p.peek = append(p.peek, token)
return token
}
func (p *ParseContext) pop() *Token {
end := len(p.peek) - 1
token := p.peek[end]
p.peek = p.peek[0:end]
return token
}
func (p *ParseContext) String() string {
return p.SelectedCommand.FullCommand()
}
func (p *ParseContext) matchedFlag(flag *FlagClause, value string) {
p.Elements = append(p.Elements, &ParseElement{Clause: flag, Value: &value})
}
func (p *ParseContext) matchedArg(arg *ArgClause, value string) {
p.Elements = append(p.Elements, &ParseElement{Clause: arg, Value: &value})
}
func (p *ParseContext) matchedCmd(cmd *CmdClause) {
p.Elements = append(p.Elements, &ParseElement{Clause: cmd})
p.mergeFlags(cmd.flagGroup)
p.mergeArgs(cmd.argGroup)
p.SelectedCommand = cmd
}
// Expand arguments from a file. Lines starting with # will be treated as comments.
func ExpandArgsFromFile(filename string) (out []string, err error) {
if filename == "" {
return nil, fmt.Errorf("expected @ file to expand arguments from")
}
r, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open arguments file %q: %s", filename, err)
}
defer r.Close()
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") || strings.TrimSpace(line) == "" {
continue
}
out = append(out, line)
}
err = scanner.Err()
if err != nil {
return nil, fmt.Errorf("failed to read arguments from %q: %s", filename, err)
}
return
}
func parse(context *ParseContext, app *Application) (err error) {
context.mergeFlags(app.flagGroup)
context.mergeArgs(app.argGroup)
cmds := app.cmdGroup
ignoreDefault := context.ignoreDefault
loop:
for !context.EOL() && !context.Error() {
token := context.Peek()
switch token.Type {
case TokenLong, TokenShort:
if flag, err := context.flags.parse(context); err != nil {
if !ignoreDefault {
if cmd := cmds.defaultSubcommand(); cmd != nil {
cmd.completionAlts = cmds.cmdNames()
context.matchedCmd(cmd)
cmds = cmd.cmdGroup
break
}
}
return err
} else if flag == HelpFlag {
ignoreDefault = true
}
case TokenArg:
if cmds.have() {
selectedDefault := false
cmd, ok := cmds.commands[token.String()]
if !ok {
if !ignoreDefault {
if cmd = cmds.defaultSubcommand(); cmd != nil {
cmd.completionAlts = cmds.cmdNames()
selectedDefault = true
}
}
if cmd == nil {
return fmt.Errorf("expected command but got %q", token)
}
}
if cmd == HelpCommand {
ignoreDefault = true
}
cmd.completionAlts = nil
context.matchedCmd(cmd)
cmds = cmd.cmdGroup
if !selectedDefault {
context.Next()
}
} else if context.arguments.have() {
if app.noInterspersed {
// no more flags
context.argsOnly = true
}
arg := context.nextArg()
if arg == nil {
break loop
}
context.matchedArg(arg, token.String())
context.Next()
} else {
break loop
}
case TokenEOL:
break loop
}
}
// Move to innermost default command.
for !ignoreDefault {
if cmd := cmds.defaultSubcommand(); cmd != nil {
cmd.completionAlts = cmds.cmdNames()
context.matchedCmd(cmd)
cmds = cmd.cmdGroup
} else {
break
}
}
if context.Error() {
return fmt.Errorf("%s", context.Peek().Value)
}
if !context.EOL() {
return fmt.Errorf("unexpected %s", context.Peek())
}
// Set defaults for all remaining args.
for arg := context.nextArg(); arg != nil && !arg.consumesRemainder(); arg = context.nextArg() {
for _, defaultValue := range arg.defaultValues {
if err := arg.value.Set(defaultValue); err != nil {
return fmt.Errorf("invalid default value '%s' for argument '%s'", defaultValue, arg.name)
}
}
}
return
}

View File

@ -1,216 +0,0 @@
package kingpin
import (
"net"
"net/url"
"os"
"time"
"github.com/alecthomas/units"
)
type Settings interface {
SetValue(value Value)
}
type parserMixin struct {
value Value
required bool
}
func (p *parserMixin) SetText(text Text) {
p.value = &wrapText{text}
}
func (p *parserMixin) SetValue(value Value) {
p.value = value
}
// StringMap provides key=value parsing into a map.
func (p *parserMixin) StringMap() (target *map[string]string) {
target = &(map[string]string{})
p.StringMapVar(target)
return
}
// Duration sets the parser to a time.Duration parser.
func (p *parserMixin) Duration() (target *time.Duration) {
target = new(time.Duration)
p.DurationVar(target)
return
}
// Bytes parses numeric byte units. eg. 1.5KB
func (p *parserMixin) Bytes() (target *units.Base2Bytes) {
target = new(units.Base2Bytes)
p.BytesVar(target)
return
}
// IP sets the parser to a net.IP parser.
func (p *parserMixin) IP() (target *net.IP) {
target = new(net.IP)
p.IPVar(target)
return
}
// TCP (host:port) address.
func (p *parserMixin) TCP() (target **net.TCPAddr) {
target = new(*net.TCPAddr)
p.TCPVar(target)
return
}
// TCPVar (host:port) address.
func (p *parserMixin) TCPVar(target **net.TCPAddr) {
p.SetValue(newTCPAddrValue(target))
}
// ExistingFile sets the parser to one that requires and returns an existing file.
func (p *parserMixin) ExistingFile() (target *string) {
target = new(string)
p.ExistingFileVar(target)
return
}
// ExistingDir sets the parser to one that requires and returns an existing directory.
func (p *parserMixin) ExistingDir() (target *string) {
target = new(string)
p.ExistingDirVar(target)
return
}
// ExistingFileOrDir sets the parser to one that requires and returns an existing file OR directory.
func (p *parserMixin) ExistingFileOrDir() (target *string) {
target = new(string)
p.ExistingFileOrDirVar(target)
return
}
// File returns an os.File against an existing file.
func (p *parserMixin) File() (target **os.File) {
target = new(*os.File)
p.FileVar(target)
return
}
// File attempts to open a File with os.OpenFile(flag, perm).
func (p *parserMixin) OpenFile(flag int, perm os.FileMode) (target **os.File) {
target = new(*os.File)
p.OpenFileVar(target, flag, perm)
return
}
// URL provides a valid, parsed url.URL.
func (p *parserMixin) URL() (target **url.URL) {
target = new(*url.URL)
p.URLVar(target)
return
}
// StringMap provides key=value parsing into a map.
func (p *parserMixin) StringMapVar(target *map[string]string) {
p.SetValue(newStringMapValue(target))
}
// Float sets the parser to a float64 parser.
func (p *parserMixin) Float() (target *float64) {
return p.Float64()
}
// Float sets the parser to a float64 parser.
func (p *parserMixin) FloatVar(target *float64) {
p.Float64Var(target)
}
// Duration sets the parser to a time.Duration parser.
func (p *parserMixin) DurationVar(target *time.Duration) {
p.SetValue(newDurationValue(target))
}
// BytesVar parses numeric byte units. eg. 1.5KB
func (p *parserMixin) BytesVar(target *units.Base2Bytes) {
p.SetValue(newBytesValue(target))
}
// IP sets the parser to a net.IP parser.
func (p *parserMixin) IPVar(target *net.IP) {
p.SetValue(newIPValue(target))
}
// ExistingFile sets the parser to one that requires and returns an existing file.
func (p *parserMixin) ExistingFileVar(target *string) {
p.SetValue(newExistingFileValue(target))
}
// ExistingDir sets the parser to one that requires and returns an existing directory.
func (p *parserMixin) ExistingDirVar(target *string) {
p.SetValue(newExistingDirValue(target))
}
// ExistingDir sets the parser to one that requires and returns an existing directory.
func (p *parserMixin) ExistingFileOrDirVar(target *string) {
p.SetValue(newExistingFileOrDirValue(target))
}
// FileVar opens an existing file.
func (p *parserMixin) FileVar(target **os.File) {
p.SetValue(newFileValue(target, os.O_RDONLY, 0))
}
// OpenFileVar calls os.OpenFile(flag, perm)
func (p *parserMixin) OpenFileVar(target **os.File, flag int, perm os.FileMode) {
p.SetValue(newFileValue(target, flag, perm))
}
// URL provides a valid, parsed url.URL.
func (p *parserMixin) URLVar(target **url.URL) {
p.SetValue(newURLValue(target))
}
// URLList provides a parsed list of url.URL values.
func (p *parserMixin) URLList() (target *[]*url.URL) {
target = new([]*url.URL)
p.URLListVar(target)
return
}
// URLListVar provides a parsed list of url.URL values.
func (p *parserMixin) URLListVar(target *[]*url.URL) {
p.SetValue(newURLListValue(target))
}
// Enum allows a value from a set of options.
func (p *parserMixin) Enum(options ...string) (target *string) {
target = new(string)
p.EnumVar(target, options...)
return
}
// EnumVar allows a value from a set of options.
func (p *parserMixin) EnumVar(target *string, options ...string) {
p.SetValue(newEnumFlag(target, options...))
}
// Enums allows a set of values from a set of options.
func (p *parserMixin) Enums(options ...string) (target *[]string) {
target = new([]string)
p.EnumsVar(target, options...)
return
}
// EnumVar allows a value from a set of options.
func (p *parserMixin) EnumsVar(target *[]string, options ...string) {
p.SetValue(newEnumsFlag(target, options...))
}
// A Counter increments a number each time it is encountered.
func (p *parserMixin) Counter() (target *int) {
target = new(int)
p.CounterVar(target)
return
}
func (p *parserMixin) CounterVar(target *int) {
p.SetValue(newCounterValue(target))
}

View File

@ -1,262 +0,0 @@
package kingpin
// Default usage template.
var DefaultUsageTemplate = `{{define "FormatCommand" -}}
{{if .FlagSummary}} {{.FlagSummary}}{{end -}}
{{range .Args}}{{if not .Hidden}} {{if not .Required}}[{{end}}{{if .PlaceHolder}}{{.PlaceHolder}}{{else}}<{{.Name}}>{{end}}{{if .Value|IsCumulative}}...{{end}}{{if not .Required}}]{{end}}{{end}}{{end -}}
{{end -}}
{{define "FormatCommands" -}}
{{range .FlattenedCommands -}}
{{if not .Hidden -}}
{{.FullCommand}}{{if .Default}}*{{end}}{{template "FormatCommand" .}}
{{.Help|Wrap 4}}
{{end -}}
{{end -}}
{{end -}}
{{define "FormatUsage" -}}
{{template "FormatCommand" .}}{{if .Commands}} <command> [<args> ...]{{end}}
{{if .Help}}
{{.Help|Wrap 0 -}}
{{end -}}
{{end -}}
{{if .Context.SelectedCommand -}}
usage: {{.App.Name}} {{.Context.SelectedCommand}}{{template "FormatUsage" .Context.SelectedCommand}}
{{ else -}}
usage: {{.App.Name}}{{template "FormatUsage" .App}}
{{end}}
{{if .Context.Flags -}}
Flags:
{{.Context.Flags|FlagsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.Args -}}
Args:
{{.Context.Args|ArgsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.SelectedCommand -}}
{{if len .Context.SelectedCommand.Commands -}}
Subcommands:
{{template "FormatCommands" .Context.SelectedCommand}}
{{end -}}
{{else if .App.Commands -}}
Commands:
{{template "FormatCommands" .App}}
{{end -}}
`
// Usage template where command's optional flags are listed separately
var SeparateOptionalFlagsUsageTemplate = `{{define "FormatCommand" -}}
{{if .FlagSummary}} {{.FlagSummary}}{{end -}}
{{range .Args}}{{if not .Hidden}} {{if not .Required}}[{{end}}{{if .PlaceHolder}}{{.PlaceHolder}}{{else}}<{{.Name}}>{{end}}{{if .Value|IsCumulative}}...{{end}}{{if not .Required}}]{{end}}{{end}}{{end -}}
{{end -}}
{{define "FormatCommands" -}}
{{range .FlattenedCommands -}}
{{if not .Hidden -}}
{{.FullCommand}}{{if .Default}}*{{end}}{{template "FormatCommand" .}}
{{.Help|Wrap 4}}
{{end -}}
{{end -}}
{{end -}}
{{define "FormatUsage" -}}
{{template "FormatCommand" .}}{{if .Commands}} <command> [<args> ...]{{end}}
{{if .Help}}
{{.Help|Wrap 0 -}}
{{end -}}
{{end -}}
{{if .Context.SelectedCommand -}}
usage: {{.App.Name}} {{.Context.SelectedCommand}}{{template "FormatUsage" .Context.SelectedCommand}}
{{else -}}
usage: {{.App.Name}}{{template "FormatUsage" .App}}
{{end -}}
{{if .Context.Flags|RequiredFlags -}}
Required flags:
{{.Context.Flags|RequiredFlags|FlagsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.Flags|OptionalFlags -}}
Optional flags:
{{.Context.Flags|OptionalFlags|FlagsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.Args -}}
Args:
{{.Context.Args|ArgsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.SelectedCommand -}}
Subcommands:
{{if .Context.SelectedCommand.Commands -}}
{{template "FormatCommands" .Context.SelectedCommand}}
{{end -}}
{{else if .App.Commands -}}
Commands:
{{template "FormatCommands" .App}}
{{end -}}
`
// Usage template with compactly formatted commands.
var CompactUsageTemplate = `{{define "FormatCommand" -}}
{{if .FlagSummary}} {{.FlagSummary}}{{end -}}
{{range .Args}}{{if not .Hidden}} {{if not .Required}}[{{end}}{{if .PlaceHolder}}{{.PlaceHolder}}{{else}}<{{.Name}}>{{end}}{{if .Value|IsCumulative}}...{{end}}{{if not .Required}}]{{end}}{{end}}{{end -}}
{{end -}}
{{define "FormatCommandList" -}}
{{range . -}}
{{if not .Hidden -}}
{{.Depth|Indent}}{{.Name}}{{if .Default}}*{{end}}{{template "FormatCommand" .}}
{{end -}}
{{template "FormatCommandList" .Commands -}}
{{end -}}
{{end -}}
{{define "FormatUsage" -}}
{{template "FormatCommand" .}}{{if .Commands}} <command> [<args> ...]{{end}}
{{if .Help}}
{{.Help|Wrap 0 -}}
{{end -}}
{{end -}}
{{if .Context.SelectedCommand -}}
usage: {{.App.Name}} {{.Context.SelectedCommand}}{{template "FormatUsage" .Context.SelectedCommand}}
{{else -}}
usage: {{.App.Name}}{{template "FormatUsage" .App}}
{{end -}}
{{if .Context.Flags -}}
Flags:
{{.Context.Flags|FlagsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.Args -}}
Args:
{{.Context.Args|ArgsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.SelectedCommand -}}
{{if .Context.SelectedCommand.Commands -}}
Commands:
{{.Context.SelectedCommand}}
{{template "FormatCommandList" .Context.SelectedCommand.Commands}}
{{end -}}
{{else if .App.Commands -}}
Commands:
{{template "FormatCommandList" .App.Commands}}
{{end -}}
`
var ManPageTemplate = `{{define "FormatFlags" -}}
{{range .Flags -}}
{{if not .Hidden -}}
.TP
\fB{{if .Short}}-{{.Short|Char}}, {{end}}--{{.Name}}{{if not .IsBoolFlag}}={{.FormatPlaceHolder}}{{end -}}\fR
{{.Help}}
{{end -}}
{{end -}}
{{end -}}
{{define "FormatCommand" -}}
{{if .FlagSummary}} {{.FlagSummary}}{{end -}}
{{range .Args}}{{if not .Hidden}} {{if not .Required}}[{{end}}{{if .PlaceHolder}}{{.PlaceHolder}}{{else}}<{{.Name}}>{{end}}{{if .Value|IsCumulative}}...{{end}}{{if not .Required}}]{{end}}{{end}}{{end -}}
{{end -}}
{{define "FormatCommands" -}}
{{range .FlattenedCommands -}}
{{if not .Hidden -}}
.SS
\fB{{.FullCommand}}{{template "FormatCommand" . -}}\fR
.PP
{{.Help}}
{{template "FormatFlags" . -}}
{{end -}}
{{end -}}
{{end -}}
{{define "FormatUsage" -}}
{{template "FormatCommand" .}}{{if .Commands}} <command> [<args> ...]{{end -}}\fR
{{end -}}
.TH {{.App.Name}} 1 {{.App.Version}} "{{.App.Author}}"
.SH "NAME"
{{.App.Name}}
.SH "SYNOPSIS"
.TP
\fB{{.App.Name}}{{template "FormatUsage" .App}}
.SH "DESCRIPTION"
{{.App.Help}}
.SH "OPTIONS"
{{template "FormatFlags" .App -}}
{{if .App.Commands -}}
.SH "COMMANDS"
{{template "FormatCommands" .App -}}
{{end -}}
`
// Default usage template.
var LongHelpTemplate = `{{define "FormatCommand" -}}
{{if .FlagSummary}} {{.FlagSummary}}{{end -}}
{{range .Args}}{{if not .Hidden}} {{if not .Required}}[{{end}}{{if .PlaceHolder}}{{.PlaceHolder}}{{else}}<{{.Name}}>{{end}}{{if .Value|IsCumulative}}...{{end}}{{if not .Required}}]{{end}}{{end}}{{end -}}
{{end -}}
{{define "FormatCommands" -}}
{{range .FlattenedCommands -}}
{{if not .Hidden -}}
{{.FullCommand}}{{template "FormatCommand" .}}
{{.Help|Wrap 4}}
{{with .Flags|FlagsToTwoColumns}}{{FormatTwoColumnsWithIndent . 4 2}}{{end}}
{{end -}}
{{end -}}
{{end -}}
{{define "FormatUsage" -}}
{{template "FormatCommand" .}}{{if .Commands}} <command> [<args> ...]{{end}}
{{if .Help}}
{{.Help|Wrap 0 -}}
{{end -}}
{{end -}}
usage: {{.App.Name}}{{template "FormatUsage" .App}}
{{if .Context.Flags -}}
Flags:
{{.Context.Flags|FlagsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .Context.Args -}}
Args:
{{.Context.Args|ArgsToTwoColumns|FormatTwoColumns}}
{{end -}}
{{if .App.Commands -}}
Commands:
{{template "FormatCommands" .App}}
{{end -}}
`
var BashCompletionTemplate = `
_{{.App.Name}}_bash_autocomplete() {
local cur prev opts base
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
opts=$( ${COMP_WORDS[0]} --completion-bash "${COMP_WORDS[@]:1:$COMP_CWORD}" )
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
return 0
}
complete -F _{{.App.Name}}_bash_autocomplete -o default {{.App.Name}}
`
var ZshCompletionTemplate = `#compdef {{.App.Name}}
_{{.App.Name}}() {
local matches=($(${words[1]} --completion-bash "${(@)words[2,$CURRENT]}"))
compadd -a matches
if [[ $compstate[nmatches] -eq 0 && $words[$CURRENT] != -* ]]; then
_files
fi
}
if [[ "$(basename -- ${(%):-%x})" != "_{{.App.Name}}" ]]; then
compdef _{{.App.Name}} {{.App.Name}}
fi
`

View File

@ -1,225 +0,0 @@
package kingpin
import (
"bytes"
"fmt"
"go/doc"
"io"
"strings"
"text/template"
)
var (
preIndent = " "
)
func formatTwoColumns(w io.Writer, indent, padding, width int, rows [][2]string) {
// Find size of first column.
s := 0
for _, row := range rows {
if c := len(row[0]); c > s && c < 30 {
s = c
}
}
indentStr := strings.Repeat(" ", indent)
offsetStr := strings.Repeat(" ", s+padding)
for _, row := range rows {
buf := bytes.NewBuffer(nil)
doc.ToText(buf, row[1], "", preIndent, width-s-padding-indent)
lines := strings.Split(strings.TrimRight(buf.String(), "\n"), "\n")
fmt.Fprintf(w, "%s%-*s%*s", indentStr, s, row[0], padding, "")
if len(row[0]) >= 30 {
fmt.Fprintf(w, "\n%s%s", indentStr, offsetStr)
}
fmt.Fprintf(w, "%s\n", lines[0])
for _, line := range lines[1:] {
fmt.Fprintf(w, "%s%s%s\n", indentStr, offsetStr, line)
}
}
}
// Usage writes application usage to w. It parses args to determine
// appropriate help context, such as which command to show help for.
func (a *Application) Usage(args []string) {
context, err := a.parseContext(true, args)
a.FatalIfError(err, "")
if err := a.UsageForContextWithTemplate(context, 2, a.usageTemplate); err != nil {
panic(err)
}
}
func formatAppUsage(app *ApplicationModel) string {
s := []string{app.Name}
if len(app.Flags) > 0 {
s = append(s, app.FlagSummary())
}
if len(app.Args) > 0 {
s = append(s, app.ArgSummary())
}
return strings.Join(s, " ")
}
func formatCmdUsage(app *ApplicationModel, cmd *CmdModel) string {
s := []string{app.Name, cmd.String()}
if len(cmd.Flags) > 0 {
s = append(s, cmd.FlagSummary())
}
if len(cmd.Args) > 0 {
s = append(s, cmd.ArgSummary())
}
return strings.Join(s, " ")
}
func formatFlag(haveShort bool, flag *FlagModel) string {
flagString := ""
flagName := flag.Name
if flag.IsBoolFlag() {
flagName = "[no-]" + flagName
}
if flag.Short != 0 {
flagString += fmt.Sprintf("-%c, --%s", flag.Short, flagName)
} else {
if haveShort {
flagString += fmt.Sprintf(" --%s", flagName)
} else {
flagString += fmt.Sprintf("--%s", flagName)
}
}
if !flag.IsBoolFlag() {
flagString += fmt.Sprintf("=%s", flag.FormatPlaceHolder())
}
if v, ok := flag.Value.(repeatableFlag); ok && v.IsCumulative() {
flagString += " ..."
}
return flagString
}
type templateParseContext struct {
SelectedCommand *CmdModel
*FlagGroupModel
*ArgGroupModel
}
type templateContext struct {
App *ApplicationModel
Width int
Context *templateParseContext
}
// UsageForContext displays usage information from a ParseContext (obtained from
// Application.ParseContext() or Action(f) callbacks).
func (a *Application) UsageForContext(context *ParseContext) error {
return a.UsageForContextWithTemplate(context, 2, a.usageTemplate)
}
// UsageForContextWithTemplate is the base usage function. You generally don't need to use this.
func (a *Application) UsageForContextWithTemplate(context *ParseContext, indent int, tmpl string) error {
width := guessWidth(a.usageWriter)
funcs := template.FuncMap{
"Indent": func(level int) string {
return strings.Repeat(" ", level*indent)
},
"Wrap": func(indent int, s string) string {
buf := bytes.NewBuffer(nil)
indentText := strings.Repeat(" ", indent)
doc.ToText(buf, s, indentText, " "+indentText, width-indent)
return buf.String()
},
"FormatFlag": formatFlag,
"FlagsToTwoColumns": func(f []*FlagModel) [][2]string {
rows := [][2]string{}
haveShort := false
for _, flag := range f {
if flag.Short != 0 {
haveShort = true
break
}
}
for _, flag := range f {
if !flag.Hidden {
rows = append(rows, [2]string{formatFlag(haveShort, flag), flag.HelpWithEnvar()})
}
}
return rows
},
"RequiredFlags": func(f []*FlagModel) []*FlagModel {
requiredFlags := []*FlagModel{}
for _, flag := range f {
if flag.Required {
requiredFlags = append(requiredFlags, flag)
}
}
return requiredFlags
},
"OptionalFlags": func(f []*FlagModel) []*FlagModel {
optionalFlags := []*FlagModel{}
for _, flag := range f {
if !flag.Required {
optionalFlags = append(optionalFlags, flag)
}
}
return optionalFlags
},
"ArgsToTwoColumns": func(a []*ArgModel) [][2]string {
rows := [][2]string{}
for _, arg := range a {
if !arg.Hidden {
var s string
if arg.PlaceHolder != "" {
s = arg.PlaceHolder
} else {
s = "<" + arg.Name + ">"
}
if !arg.Required {
s = "[" + s + "]"
}
rows = append(rows, [2]string{s, arg.HelpWithEnvar()})
}
}
return rows
},
"FormatTwoColumns": func(rows [][2]string) string {
buf := bytes.NewBuffer(nil)
formatTwoColumns(buf, indent, indent, width, rows)
return buf.String()
},
"FormatTwoColumnsWithIndent": func(rows [][2]string, indent, padding int) string {
buf := bytes.NewBuffer(nil)
formatTwoColumns(buf, indent, padding, width, rows)
return buf.String()
},
"FormatAppUsage": formatAppUsage,
"FormatCommandUsage": formatCmdUsage,
"IsCumulative": func(value Value) bool {
r, ok := value.(remainderArg)
return ok && r.IsCumulative()
},
"Char": func(c rune) string {
return string(c)
},
}
for k, v := range a.usageFuncs {
funcs[k] = v
}
t, err := template.New("usage").Funcs(funcs).Parse(tmpl)
if err != nil {
return err
}
var selectedCommand *CmdModel
if context.SelectedCommand != nil {
selectedCommand = context.SelectedCommand.Model()
}
ctx := templateContext{
App: a.Model(),
Width: width,
Context: &templateParseContext{
SelectedCommand: selectedCommand,
FlagGroupModel: context.flags.Model(),
ArgGroupModel: context.arguments.Model(),
},
}
return t.Execute(a.usageWriter, ctx)
}

View File

@ -1,489 +0,0 @@
package kingpin
//go:generate go run ./cmd/genvalues/main.go
import (
"encoding"
"fmt"
"net"
"net/url"
"os"
"reflect"
"regexp"
"strings"
"time"
"github.com/alecthomas/units"
"github.com/xhit/go-str2duration/v2"
)
// NOTE: Most of the base type values were lifted from:
// http://golang.org/src/pkg/flag/flag.go?s=20146:20222
// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
//
// If a Value has an IsBoolFlag() bool method returning true, the command-line
// parser makes --name equivalent to -name=true rather than using the next
// command-line argument, and adds a --no-name counterpart for negating the
// flag.
type Value interface {
String() string
Set(string) error
}
// Getter is an interface that allows the contents of a Value to be retrieved.
// It wraps the Value interface, rather than being part of it, because it
// appeared after Go 1 and its compatibility rules. All Value types provided
// by this package satisfy the Getter interface.
type Getter interface {
Value
Get() interface{}
}
// Optional interface to indicate boolean flags that don't accept a value, and
// implicitly have a --no-<x> negation counterpart.
type boolFlag interface {
IsBoolFlag() bool
}
// Optional interface for arguments that cumulatively consume all remaining
// input.
type remainderArg interface {
IsCumulative() bool
}
// Optional interface for flags that can be repeated.
type repeatableFlag interface {
IsCumulative() bool
}
// Text is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
type Text interface {
encoding.TextMarshaler
encoding.TextUnmarshaler
}
type wrapText struct {
text Text
}
func (w wrapText) String() string {
buf, _ := w.text.MarshalText()
return string(buf)
}
func (w *wrapText) Set(s string) error {
return w.text.UnmarshalText([]byte(s))
}
type accumulator struct {
element func(value interface{}) Value
typ reflect.Type
slice reflect.Value
}
// Use reflection to accumulate values into a slice.
//
// target := []string{}
// newAccumulator(&target, func (value interface{}) Value {
// return newStringValue(value.(*string))
// })
func newAccumulator(slice interface{}, element func(value interface{}) Value) *accumulator {
typ := reflect.TypeOf(slice)
if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Slice {
panic("expected a pointer to a slice")
}
return &accumulator{
element: element,
typ: typ.Elem().Elem(),
slice: reflect.ValueOf(slice),
}
}
func (a *accumulator) String() string {
out := []string{}
s := a.slice.Elem()
for i := 0; i < s.Len(); i++ {
out = append(out, a.element(s.Index(i).Addr().Interface()).String())
}
return strings.Join(out, ",")
}
func (a *accumulator) Set(value string) error {
e := reflect.New(a.typ)
if err := a.element(e.Interface()).Set(value); err != nil {
return err
}
slice := reflect.Append(a.slice.Elem(), e.Elem())
a.slice.Elem().Set(slice)
return nil
}
func (a *accumulator) Get() interface{} {
return a.slice.Interface()
}
func (a *accumulator) IsCumulative() bool {
return true
}
func (b *boolValue) IsBoolFlag() bool { return true }
// -- time.Duration Value
type durationValue time.Duration
func newDurationValue(p *time.Duration) *durationValue {
return (*durationValue)(p)
}
func (d *durationValue) Set(s string) error {
v, err := str2duration.ParseDuration(s)
*d = durationValue(v)
return err
}
func (d *durationValue) Get() interface{} { return time.Duration(*d) }
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
// -- map[string]string Value
type stringMapValue map[string]string
func newStringMapValue(p *map[string]string) *stringMapValue {
return (*stringMapValue)(p)
}
var stringMapRegex = regexp.MustCompile("[:=]")
func (s *stringMapValue) Set(value string) error {
parts := stringMapRegex.Split(value, 2)
if len(parts) != 2 {
return fmt.Errorf("expected KEY=VALUE got '%s'", value)
}
(*s)[parts[0]] = parts[1]
return nil
}
func (s *stringMapValue) Get() interface{} {
return (map[string]string)(*s)
}
func (s *stringMapValue) String() string {
return fmt.Sprintf("%s", map[string]string(*s))
}
func (s *stringMapValue) IsCumulative() bool {
return true
}
// -- net.IP Value
type ipValue net.IP
func newIPValue(p *net.IP) *ipValue {
return (*ipValue)(p)
}
func (i *ipValue) Set(value string) error {
if ip := net.ParseIP(value); ip == nil {
return fmt.Errorf("'%s' is not an IP address", value)
} else {
*i = *(*ipValue)(&ip)
return nil
}
}
func (i *ipValue) Get() interface{} {
return (net.IP)(*i)
}
func (i *ipValue) String() string {
return (*net.IP)(i).String()
}
// -- *net.TCPAddr Value
type tcpAddrValue struct {
addr **net.TCPAddr
}
func newTCPAddrValue(p **net.TCPAddr) *tcpAddrValue {
return &tcpAddrValue{p}
}
func (i *tcpAddrValue) Set(value string) error {
if addr, err := net.ResolveTCPAddr("tcp", value); err != nil {
return fmt.Errorf("'%s' is not a valid TCP address: %s", value, err)
} else {
*i.addr = addr
return nil
}
}
func (t *tcpAddrValue) Get() interface{} {
return (*net.TCPAddr)(*t.addr)
}
func (i *tcpAddrValue) String() string {
return (*i.addr).String()
}
// -- existingFile Value
type fileStatValue struct {
path *string
predicate func(os.FileInfo) error
}
func newFileStatValue(p *string, predicate func(os.FileInfo) error) *fileStatValue {
return &fileStatValue{
path: p,
predicate: predicate,
}
}
func (e *fileStatValue) Set(value string) error {
if s, err := os.Stat(value); os.IsNotExist(err) {
return fmt.Errorf("path '%s' does not exist", value)
} else if err != nil {
return err
} else if err := e.predicate(s); err != nil {
return err
}
*e.path = value
return nil
}
func (f *fileStatValue) Get() interface{} {
return (string)(*f.path)
}
func (e *fileStatValue) String() string {
return *e.path
}
// -- os.File value
type fileValue struct {
f **os.File
flag int
perm os.FileMode
}
func newFileValue(p **os.File, flag int, perm os.FileMode) *fileValue {
return &fileValue{p, flag, perm}
}
func (f *fileValue) Set(value string) error {
if fd, err := os.OpenFile(value, f.flag, f.perm); err != nil {
return err
} else {
*f.f = fd
return nil
}
}
func (f *fileValue) Get() interface{} {
return (*os.File)(*f.f)
}
func (f *fileValue) String() string {
if *f.f == nil {
return "<nil>"
}
return (*f.f).Name()
}
// -- url.URL Value
type urlValue struct {
u **url.URL
}
func newURLValue(p **url.URL) *urlValue {
return &urlValue{p}
}
func (u *urlValue) Set(value string) error {
if url, err := url.Parse(value); err != nil {
return fmt.Errorf("invalid URL: %s", err)
} else {
*u.u = url
return nil
}
}
func (u *urlValue) Get() interface{} {
return (*url.URL)(*u.u)
}
func (u *urlValue) String() string {
if *u.u == nil {
return "<nil>"
}
return (*u.u).String()
}
// -- []*url.URL Value
type urlListValue []*url.URL
func newURLListValue(p *[]*url.URL) *urlListValue {
return (*urlListValue)(p)
}
func (u *urlListValue) Set(value string) error {
if url, err := url.Parse(value); err != nil {
return fmt.Errorf("invalid URL: %s", err)
} else {
*u = append(*u, url)
return nil
}
}
func (u *urlListValue) Get() interface{} {
return ([]*url.URL)(*u)
}
func (u *urlListValue) String() string {
out := []string{}
for _, url := range *u {
out = append(out, url.String())
}
return strings.Join(out, ",")
}
func (u *urlListValue) IsCumulative() bool {
return true
}
// A flag whose value must be in a set of options.
type enumValue struct {
value *string
options []string
}
func newEnumFlag(target *string, options ...string) *enumValue {
return &enumValue{
value: target,
options: options,
}
}
func (a *enumValue) String() string {
return *a.value
}
func (a *enumValue) Set(value string) error {
for _, v := range a.options {
if v == value {
*a.value = value
return nil
}
}
return fmt.Errorf("enum value must be one of %s, got '%s'", strings.Join(a.options, ","), value)
}
func (e *enumValue) Get() interface{} {
return (string)(*e.value)
}
// -- []string Enum Value
type enumsValue struct {
value *[]string
options []string
}
func newEnumsFlag(target *[]string, options ...string) *enumsValue {
return &enumsValue{
value: target,
options: options,
}
}
func (s *enumsValue) Set(value string) error {
for _, v := range s.options {
if v == value {
*s.value = append(*s.value, value)
return nil
}
}
return fmt.Errorf("enum value must be one of %s, got '%s'", strings.Join(s.options, ","), value)
}
func (e *enumsValue) Get() interface{} {
return ([]string)(*e.value)
}
func (s *enumsValue) String() string {
return strings.Join(*s.value, ",")
}
func (s *enumsValue) IsCumulative() bool {
return true
}
// -- units.Base2Bytes Value
type bytesValue units.Base2Bytes
func newBytesValue(p *units.Base2Bytes) *bytesValue {
return (*bytesValue)(p)
}
func (d *bytesValue) Set(s string) error {
v, err := units.ParseBase2Bytes(s)
*d = bytesValue(v)
return err
}
func (d *bytesValue) Get() interface{} { return units.Base2Bytes(*d) }
func (d *bytesValue) String() string { return (*units.Base2Bytes)(d).String() }
func newExistingFileValue(target *string) *fileStatValue {
return newFileStatValue(target, func(s os.FileInfo) error {
if s.IsDir() {
return fmt.Errorf("'%s' is a directory", s.Name())
}
return nil
})
}
func newExistingDirValue(target *string) *fileStatValue {
return newFileStatValue(target, func(s os.FileInfo) error {
if !s.IsDir() {
return fmt.Errorf("'%s' is a file", s.Name())
}
return nil
})
}
func newExistingFileOrDirValue(target *string) *fileStatValue {
return newFileStatValue(target, func(s os.FileInfo) error { return nil })
}
type counterValue int
func newCounterValue(n *int) *counterValue {
return (*counterValue)(n)
}
func (c *counterValue) Set(s string) error {
*c++
return nil
}
func (c *counterValue) Get() interface{} { return (int)(*c) }
func (c *counterValue) IsBoolFlag() bool { return true }
func (c *counterValue) String() string { return fmt.Sprintf("%d", *c) }
func (c *counterValue) IsCumulative() bool { return true }
func resolveHost(value string) (net.IP, error) {
if ip := net.ParseIP(value); ip != nil {
return ip, nil
} else {
if addr, err := net.ResolveIPAddr("ip", value); err != nil {
return nil, err
} else {
return addr.IP, nil
}
}
}

View File

@ -1,25 +0,0 @@
[
{"type": "bool", "parser": "strconv.ParseBool(s)"},
{"type": "string", "parser": "s, error(nil)", "format": "string(*f.v)", "plural": "Strings"},
{"type": "uint", "parser": "strconv.ParseUint(s, 0, 64)", "plural": "Uints"},
{"type": "uint8", "parser": "strconv.ParseUint(s, 0, 8)"},
{"type": "uint16", "parser": "strconv.ParseUint(s, 0, 16)"},
{"type": "uint32", "parser": "strconv.ParseUint(s, 0, 32)"},
{"type": "uint64", "parser": "strconv.ParseUint(s, 0, 64)"},
{"type": "int", "parser": "strconv.ParseFloat(s, 64)", "plural": "Ints"},
{"type": "int8", "parser": "strconv.ParseInt(s, 0, 8)"},
{"type": "int16", "parser": "strconv.ParseInt(s, 0, 16)"},
{"type": "int32", "parser": "strconv.ParseInt(s, 0, 32)"},
{"type": "int64", "parser": "strconv.ParseInt(s, 0, 64)"},
{"type": "float64", "parser": "strconv.ParseFloat(s, 64)"},
{"type": "float32", "parser": "strconv.ParseFloat(s, 32)"},
{"name": "Duration", "type": "time.Duration", "no_value_parser": true},
{"name": "IP", "type": "net.IP", "no_value_parser": true},
{"name": "TCPAddr", "Type": "*net.TCPAddr", "plural": "TCPList", "no_value_parser": true},
{"name": "ExistingFile", "Type": "string", "plural": "ExistingFiles", "no_value_parser": true},
{"name": "ExistingDir", "Type": "string", "plural": "ExistingDirs", "no_value_parser": true},
{"name": "ExistingFileOrDir", "Type": "string", "plural": "ExistingFilesOrDirs", "no_value_parser": true},
{"name": "Regexp", "Type": "*regexp.Regexp", "parser": "regexp.Compile(s)"},
{"name": "ResolvedIP", "Type": "net.IP", "parser": "resolveHost(s)", "help": "Resolve a hostname or IP to an IP."},
{"name": "HexBytes", "Type": "[]byte", "parser": "hex.DecodeString(s)", "help": "Bytes as a hex string."}
]

View File

@ -1,821 +0,0 @@
package kingpin
import (
"encoding/hex"
"fmt"
"net"
"regexp"
"strconv"
"time"
)
// This file is autogenerated by "go generate .". Do not modify.
// -- bool Value
type boolValue struct{ v *bool }
func newBoolValue(p *bool) *boolValue {
return &boolValue{p}
}
func (f *boolValue) Set(s string) error {
v, err := strconv.ParseBool(s)
if err == nil {
*f.v = (bool)(v)
}
return err
}
func (f *boolValue) Get() interface{} { return (bool)(*f.v) }
func (f *boolValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Bool parses the next command-line value as bool.
func (p *parserMixin) Bool() (target *bool) {
target = new(bool)
p.BoolVar(target)
return
}
func (p *parserMixin) BoolVar(target *bool) {
p.SetValue(newBoolValue(target))
}
// BoolList accumulates bool values into a slice.
func (p *parserMixin) BoolList() (target *[]bool) {
target = new([]bool)
p.BoolListVar(target)
return
}
func (p *parserMixin) BoolListVar(target *[]bool) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newBoolValue(v.(*bool))
}))
}
// -- string Value
type stringValue struct{ v *string }
func newStringValue(p *string) *stringValue {
return &stringValue{p}
}
func (f *stringValue) Set(s string) error {
v, err := s, error(nil)
if err == nil {
*f.v = (string)(v)
}
return err
}
func (f *stringValue) Get() interface{} { return (string)(*f.v) }
func (f *stringValue) String() string { return string(*f.v) }
// String parses the next command-line value as string.
func (p *parserMixin) String() (target *string) {
target = new(string)
p.StringVar(target)
return
}
func (p *parserMixin) StringVar(target *string) {
p.SetValue(newStringValue(target))
}
// Strings accumulates string values into a slice.
func (p *parserMixin) Strings() (target *[]string) {
target = new([]string)
p.StringsVar(target)
return
}
func (p *parserMixin) StringsVar(target *[]string) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newStringValue(v.(*string))
}))
}
// -- uint Value
type uintValue struct{ v *uint }
func newUintValue(p *uint) *uintValue {
return &uintValue{p}
}
func (f *uintValue) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64)
if err == nil {
*f.v = (uint)(v)
}
return err
}
func (f *uintValue) Get() interface{} { return (uint)(*f.v) }
func (f *uintValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint parses the next command-line value as uint.
func (p *parserMixin) Uint() (target *uint) {
target = new(uint)
p.UintVar(target)
return
}
func (p *parserMixin) UintVar(target *uint) {
p.SetValue(newUintValue(target))
}
// Uints accumulates uint values into a slice.
func (p *parserMixin) Uints() (target *[]uint) {
target = new([]uint)
p.UintsVar(target)
return
}
func (p *parserMixin) UintsVar(target *[]uint) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newUintValue(v.(*uint))
}))
}
// -- uint8 Value
type uint8Value struct{ v *uint8 }
func newUint8Value(p *uint8) *uint8Value {
return &uint8Value{p}
}
func (f *uint8Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 8)
if err == nil {
*f.v = (uint8)(v)
}
return err
}
func (f *uint8Value) Get() interface{} { return (uint8)(*f.v) }
func (f *uint8Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint8 parses the next command-line value as uint8.
func (p *parserMixin) Uint8() (target *uint8) {
target = new(uint8)
p.Uint8Var(target)
return
}
func (p *parserMixin) Uint8Var(target *uint8) {
p.SetValue(newUint8Value(target))
}
// Uint8List accumulates uint8 values into a slice.
func (p *parserMixin) Uint8List() (target *[]uint8) {
target = new([]uint8)
p.Uint8ListVar(target)
return
}
func (p *parserMixin) Uint8ListVar(target *[]uint8) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newUint8Value(v.(*uint8))
}))
}
// -- uint16 Value
type uint16Value struct{ v *uint16 }
func newUint16Value(p *uint16) *uint16Value {
return &uint16Value{p}
}
func (f *uint16Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 16)
if err == nil {
*f.v = (uint16)(v)
}
return err
}
func (f *uint16Value) Get() interface{} { return (uint16)(*f.v) }
func (f *uint16Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint16 parses the next command-line value as uint16.
func (p *parserMixin) Uint16() (target *uint16) {
target = new(uint16)
p.Uint16Var(target)
return
}
func (p *parserMixin) Uint16Var(target *uint16) {
p.SetValue(newUint16Value(target))
}
// Uint16List accumulates uint16 values into a slice.
func (p *parserMixin) Uint16List() (target *[]uint16) {
target = new([]uint16)
p.Uint16ListVar(target)
return
}
func (p *parserMixin) Uint16ListVar(target *[]uint16) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newUint16Value(v.(*uint16))
}))
}
// -- uint32 Value
type uint32Value struct{ v *uint32 }
func newUint32Value(p *uint32) *uint32Value {
return &uint32Value{p}
}
func (f *uint32Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 32)
if err == nil {
*f.v = (uint32)(v)
}
return err
}
func (f *uint32Value) Get() interface{} { return (uint32)(*f.v) }
func (f *uint32Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint32 parses the next command-line value as uint32.
func (p *parserMixin) Uint32() (target *uint32) {
target = new(uint32)
p.Uint32Var(target)
return
}
func (p *parserMixin) Uint32Var(target *uint32) {
p.SetValue(newUint32Value(target))
}
// Uint32List accumulates uint32 values into a slice.
func (p *parserMixin) Uint32List() (target *[]uint32) {
target = new([]uint32)
p.Uint32ListVar(target)
return
}
func (p *parserMixin) Uint32ListVar(target *[]uint32) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newUint32Value(v.(*uint32))
}))
}
// -- uint64 Value
type uint64Value struct{ v *uint64 }
func newUint64Value(p *uint64) *uint64Value {
return &uint64Value{p}
}
func (f *uint64Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64)
if err == nil {
*f.v = (uint64)(v)
}
return err
}
func (f *uint64Value) Get() interface{} { return (uint64)(*f.v) }
func (f *uint64Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint64 parses the next command-line value as uint64.
func (p *parserMixin) Uint64() (target *uint64) {
target = new(uint64)
p.Uint64Var(target)
return
}
func (p *parserMixin) Uint64Var(target *uint64) {
p.SetValue(newUint64Value(target))
}
// Uint64List accumulates uint64 values into a slice.
func (p *parserMixin) Uint64List() (target *[]uint64) {
target = new([]uint64)
p.Uint64ListVar(target)
return
}
func (p *parserMixin) Uint64ListVar(target *[]uint64) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newUint64Value(v.(*uint64))
}))
}
// -- int Value
type intValue struct{ v *int }
func newIntValue(p *int) *intValue {
return &intValue{p}
}
func (f *intValue) Set(s string) error {
v, err := strconv.ParseFloat(s, 64)
if err == nil {
*f.v = (int)(v)
}
return err
}
func (f *intValue) Get() interface{} { return (int)(*f.v) }
func (f *intValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Int parses the next command-line value as int.
func (p *parserMixin) Int() (target *int) {
target = new(int)
p.IntVar(target)
return
}
func (p *parserMixin) IntVar(target *int) {
p.SetValue(newIntValue(target))
}
// Ints accumulates int values into a slice.
func (p *parserMixin) Ints() (target *[]int) {
target = new([]int)
p.IntsVar(target)
return
}
func (p *parserMixin) IntsVar(target *[]int) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newIntValue(v.(*int))
}))
}
// -- int8 Value
type int8Value struct{ v *int8 }
func newInt8Value(p *int8) *int8Value {
return &int8Value{p}
}
func (f *int8Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 8)
if err == nil {
*f.v = (int8)(v)
}
return err
}
func (f *int8Value) Get() interface{} { return (int8)(*f.v) }
func (f *int8Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int8 parses the next command-line value as int8.
func (p *parserMixin) Int8() (target *int8) {
target = new(int8)
p.Int8Var(target)
return
}
func (p *parserMixin) Int8Var(target *int8) {
p.SetValue(newInt8Value(target))
}
// Int8List accumulates int8 values into a slice.
func (p *parserMixin) Int8List() (target *[]int8) {
target = new([]int8)
p.Int8ListVar(target)
return
}
func (p *parserMixin) Int8ListVar(target *[]int8) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newInt8Value(v.(*int8))
}))
}
// -- int16 Value
type int16Value struct{ v *int16 }
func newInt16Value(p *int16) *int16Value {
return &int16Value{p}
}
func (f *int16Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 16)
if err == nil {
*f.v = (int16)(v)
}
return err
}
func (f *int16Value) Get() interface{} { return (int16)(*f.v) }
func (f *int16Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int16 parses the next command-line value as int16.
func (p *parserMixin) Int16() (target *int16) {
target = new(int16)
p.Int16Var(target)
return
}
func (p *parserMixin) Int16Var(target *int16) {
p.SetValue(newInt16Value(target))
}
// Int16List accumulates int16 values into a slice.
func (p *parserMixin) Int16List() (target *[]int16) {
target = new([]int16)
p.Int16ListVar(target)
return
}
func (p *parserMixin) Int16ListVar(target *[]int16) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newInt16Value(v.(*int16))
}))
}
// -- int32 Value
type int32Value struct{ v *int32 }
func newInt32Value(p *int32) *int32Value {
return &int32Value{p}
}
func (f *int32Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 32)
if err == nil {
*f.v = (int32)(v)
}
return err
}
func (f *int32Value) Get() interface{} { return (int32)(*f.v) }
func (f *int32Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int32 parses the next command-line value as int32.
func (p *parserMixin) Int32() (target *int32) {
target = new(int32)
p.Int32Var(target)
return
}
func (p *parserMixin) Int32Var(target *int32) {
p.SetValue(newInt32Value(target))
}
// Int32List accumulates int32 values into a slice.
func (p *parserMixin) Int32List() (target *[]int32) {
target = new([]int32)
p.Int32ListVar(target)
return
}
func (p *parserMixin) Int32ListVar(target *[]int32) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newInt32Value(v.(*int32))
}))
}
// -- int64 Value
type int64Value struct{ v *int64 }
func newInt64Value(p *int64) *int64Value {
return &int64Value{p}
}
func (f *int64Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64)
if err == nil {
*f.v = (int64)(v)
}
return err
}
func (f *int64Value) Get() interface{} { return (int64)(*f.v) }
func (f *int64Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int64 parses the next command-line value as int64.
func (p *parserMixin) Int64() (target *int64) {
target = new(int64)
p.Int64Var(target)
return
}
func (p *parserMixin) Int64Var(target *int64) {
p.SetValue(newInt64Value(target))
}
// Int64List accumulates int64 values into a slice.
func (p *parserMixin) Int64List() (target *[]int64) {
target = new([]int64)
p.Int64ListVar(target)
return
}
func (p *parserMixin) Int64ListVar(target *[]int64) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newInt64Value(v.(*int64))
}))
}
// -- float64 Value
type float64Value struct{ v *float64 }
func newFloat64Value(p *float64) *float64Value {
return &float64Value{p}
}
func (f *float64Value) Set(s string) error {
v, err := strconv.ParseFloat(s, 64)
if err == nil {
*f.v = (float64)(v)
}
return err
}
func (f *float64Value) Get() interface{} { return (float64)(*f.v) }
func (f *float64Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Float64 parses the next command-line value as float64.
func (p *parserMixin) Float64() (target *float64) {
target = new(float64)
p.Float64Var(target)
return
}
func (p *parserMixin) Float64Var(target *float64) {
p.SetValue(newFloat64Value(target))
}
// Float64List accumulates float64 values into a slice.
func (p *parserMixin) Float64List() (target *[]float64) {
target = new([]float64)
p.Float64ListVar(target)
return
}
func (p *parserMixin) Float64ListVar(target *[]float64) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newFloat64Value(v.(*float64))
}))
}
// -- float32 Value
type float32Value struct{ v *float32 }
func newFloat32Value(p *float32) *float32Value {
return &float32Value{p}
}
func (f *float32Value) Set(s string) error {
v, err := strconv.ParseFloat(s, 32)
if err == nil {
*f.v = (float32)(v)
}
return err
}
func (f *float32Value) Get() interface{} { return (float32)(*f.v) }
func (f *float32Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Float32 parses the next command-line value as float32.
func (p *parserMixin) Float32() (target *float32) {
target = new(float32)
p.Float32Var(target)
return
}
func (p *parserMixin) Float32Var(target *float32) {
p.SetValue(newFloat32Value(target))
}
// Float32List accumulates float32 values into a slice.
func (p *parserMixin) Float32List() (target *[]float32) {
target = new([]float32)
p.Float32ListVar(target)
return
}
func (p *parserMixin) Float32ListVar(target *[]float32) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newFloat32Value(v.(*float32))
}))
}
// DurationList accumulates time.Duration values into a slice.
func (p *parserMixin) DurationList() (target *[]time.Duration) {
target = new([]time.Duration)
p.DurationListVar(target)
return
}
func (p *parserMixin) DurationListVar(target *[]time.Duration) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newDurationValue(v.(*time.Duration))
}))
}
// IPList accumulates net.IP values into a slice.
func (p *parserMixin) IPList() (target *[]net.IP) {
target = new([]net.IP)
p.IPListVar(target)
return
}
func (p *parserMixin) IPListVar(target *[]net.IP) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newIPValue(v.(*net.IP))
}))
}
// TCPList accumulates *net.TCPAddr values into a slice.
func (p *parserMixin) TCPList() (target *[]*net.TCPAddr) {
target = new([]*net.TCPAddr)
p.TCPListVar(target)
return
}
func (p *parserMixin) TCPListVar(target *[]*net.TCPAddr) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newTCPAddrValue(v.(**net.TCPAddr))
}))
}
// ExistingFiles accumulates string values into a slice.
func (p *parserMixin) ExistingFiles() (target *[]string) {
target = new([]string)
p.ExistingFilesVar(target)
return
}
func (p *parserMixin) ExistingFilesVar(target *[]string) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newExistingFileValue(v.(*string))
}))
}
// ExistingDirs accumulates string values into a slice.
func (p *parserMixin) ExistingDirs() (target *[]string) {
target = new([]string)
p.ExistingDirsVar(target)
return
}
func (p *parserMixin) ExistingDirsVar(target *[]string) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newExistingDirValue(v.(*string))
}))
}
// ExistingFilesOrDirs accumulates string values into a slice.
func (p *parserMixin) ExistingFilesOrDirs() (target *[]string) {
target = new([]string)
p.ExistingFilesOrDirsVar(target)
return
}
func (p *parserMixin) ExistingFilesOrDirsVar(target *[]string) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newExistingFileOrDirValue(v.(*string))
}))
}
// -- *regexp.Regexp Value
type regexpValue struct{ v **regexp.Regexp }
func newRegexpValue(p **regexp.Regexp) *regexpValue {
return &regexpValue{p}
}
func (f *regexpValue) Set(s string) error {
v, err := regexp.Compile(s)
if err == nil {
*f.v = (*regexp.Regexp)(v)
}
return err
}
func (f *regexpValue) Get() interface{} { return (*regexp.Regexp)(*f.v) }
func (f *regexpValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Regexp parses the next command-line value as *regexp.Regexp.
func (p *parserMixin) Regexp() (target **regexp.Regexp) {
target = new(*regexp.Regexp)
p.RegexpVar(target)
return
}
func (p *parserMixin) RegexpVar(target **regexp.Regexp) {
p.SetValue(newRegexpValue(target))
}
// RegexpList accumulates *regexp.Regexp values into a slice.
func (p *parserMixin) RegexpList() (target *[]*regexp.Regexp) {
target = new([]*regexp.Regexp)
p.RegexpListVar(target)
return
}
func (p *parserMixin) RegexpListVar(target *[]*regexp.Regexp) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newRegexpValue(v.(**regexp.Regexp))
}))
}
// -- net.IP Value
type resolvedIPValue struct{ v *net.IP }
func newResolvedIPValue(p *net.IP) *resolvedIPValue {
return &resolvedIPValue{p}
}
func (f *resolvedIPValue) Set(s string) error {
v, err := resolveHost(s)
if err == nil {
*f.v = (net.IP)(v)
}
return err
}
func (f *resolvedIPValue) Get() interface{} { return (net.IP)(*f.v) }
func (f *resolvedIPValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Resolve a hostname or IP to an IP.
func (p *parserMixin) ResolvedIP() (target *net.IP) {
target = new(net.IP)
p.ResolvedIPVar(target)
return
}
func (p *parserMixin) ResolvedIPVar(target *net.IP) {
p.SetValue(newResolvedIPValue(target))
}
// ResolvedIPList accumulates net.IP values into a slice.
func (p *parserMixin) ResolvedIPList() (target *[]net.IP) {
target = new([]net.IP)
p.ResolvedIPListVar(target)
return
}
func (p *parserMixin) ResolvedIPListVar(target *[]net.IP) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newResolvedIPValue(v.(*net.IP))
}))
}
// -- []byte Value
type hexBytesValue struct{ v *[]byte }
func newHexBytesValue(p *[]byte) *hexBytesValue {
return &hexBytesValue{p}
}
func (f *hexBytesValue) Set(s string) error {
v, err := hex.DecodeString(s)
if err == nil {
*f.v = ([]byte)(v)
}
return err
}
func (f *hexBytesValue) Get() interface{} { return ([]byte)(*f.v) }
func (f *hexBytesValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Bytes as a hex string.
func (p *parserMixin) HexBytes() (target *[]byte) {
target = new([]byte)
p.HexBytesVar(target)
return
}
func (p *parserMixin) HexBytesVar(target *[]byte) {
p.SetValue(newHexBytesValue(target))
}
// HexBytesList accumulates []byte values into a slice.
func (p *parserMixin) HexBytesList() (target *[][]byte) {
target = new([][]byte)
p.HexBytesListVar(target)
return
}
func (p *parserMixin) HexBytesListVar(target *[][]byte) {
p.SetValue(newAccumulator(target, func(v interface{}) Value {
return newHexBytesValue(v.(*[]byte))
}))
}

View File

@ -1,19 +0,0 @@
Copyright (C) 2014 Alec Thomas
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,13 +0,0 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/alecthomas/units.svg)](https://pkg.go.dev/github.com/alecthomas/units)
# Units - Helpful unit multipliers and functions for Go
The goal of this package is to have functionality similar to the [time](http://golang.org/pkg/time/) package.
It allows for code like this:
```go
n, err := ParseBase2Bytes("1KB")
// n == 1024
n = units.Mebibyte * 512
```

View File

@ -1,209 +0,0 @@
package units
// Base2Bytes is the old non-SI power-of-2 byte scale (1024 bytes in a kilobyte,
// etc.).
type Base2Bytes int64
// Base-2 byte units.
const (
Kibibyte Base2Bytes = 1024
KiB = Kibibyte
Mebibyte = Kibibyte * 1024
MiB = Mebibyte
Gibibyte = Mebibyte * 1024
GiB = Gibibyte
Tebibyte = Gibibyte * 1024
TiB = Tebibyte
Pebibyte = Tebibyte * 1024
PiB = Pebibyte
Exbibyte = Pebibyte * 1024
EiB = Exbibyte
)
var (
bytesUnitMap = MakeUnitMap("iB", "B", 1024)
oldBytesUnitMap = MakeUnitMap("B", "B", 1024)
)
// ParseBase2Bytes supports both iB and B in base-2 multipliers. That is, KB
// and KiB are both 1024.
// However "kB", which is the correct SI spelling of 1000 Bytes, is rejected.
func ParseBase2Bytes(s string) (Base2Bytes, error) {
n, err := ParseUnit(s, bytesUnitMap)
if err != nil {
n, err = ParseUnit(s, oldBytesUnitMap)
}
return Base2Bytes(n), err
}
func (b Base2Bytes) String() string {
return ToString(int64(b), 1024, "iB", "B")
}
// MarshalText implement encoding.TextMarshaler to process json/yaml.
func (b Base2Bytes) MarshalText() ([]byte, error) {
return []byte(b.String()), nil
}
// UnmarshalText implement encoding.TextUnmarshaler to process json/yaml.
func (b *Base2Bytes) UnmarshalText(text []byte) error {
n, err := ParseBase2Bytes(string(text))
*b = n
return err
}
// Floor returns Base2Bytes with all but the largest unit zeroed out. So that e.g. 1GiB1MiB1KiB → 1GiB.
func (b Base2Bytes) Floor() Base2Bytes {
switch {
case b > Exbibyte:
return (b / Exbibyte) * Exbibyte
case b > Pebibyte:
return (b / Pebibyte) * Pebibyte
case b > Tebibyte:
return (b / Tebibyte) * Tebibyte
case b > Gibibyte:
return (b / Gibibyte) * Gibibyte
case b > Mebibyte:
return (b / Mebibyte) * Mebibyte
case b > Kibibyte:
return (b / Kibibyte) * Kibibyte
default:
return b
}
}
// Round returns Base2Bytes with all but the first n units zeroed out. So that e.g. 1GiB1MiB1KiB → 1GiB1MiB, if n is 2.
func (b Base2Bytes) Round(n int) Base2Bytes {
idx := 0
switch {
case b > Exbibyte:
idx = n
case b > Pebibyte:
idx = n + 1
case b > Tebibyte:
idx = n + 2
case b > Gibibyte:
idx = n + 3
case b > Mebibyte:
idx = n + 4
case b > Kibibyte:
idx = n + 5
}
switch idx {
case 1:
return b - b%Exbibyte
case 2:
return b - b%Pebibyte
case 3:
return b - b%Tebibyte
case 4:
return b - b%Gibibyte
case 5:
return b - b%Mebibyte
case 6:
return b - b%Kibibyte
default:
return b
}
}
var metricBytesUnitMap = MakeUnitMap("B", "B", 1000)
// MetricBytes are SI byte units (1000 bytes in a kilobyte).
type MetricBytes SI
// SI base-10 byte units.
const (
Kilobyte MetricBytes = 1000
KB = Kilobyte
Megabyte = Kilobyte * 1000
MB = Megabyte
Gigabyte = Megabyte * 1000
GB = Gigabyte
Terabyte = Gigabyte * 1000
TB = Terabyte
Petabyte = Terabyte * 1000
PB = Petabyte
Exabyte = Petabyte * 1000
EB = Exabyte
)
// ParseMetricBytes parses base-10 metric byte units. That is, KB is 1000 bytes.
func ParseMetricBytes(s string) (MetricBytes, error) {
n, err := ParseUnit(s, metricBytesUnitMap)
return MetricBytes(n), err
}
// TODO: represents 1000B as uppercase "KB", while SI standard requires "kB".
func (m MetricBytes) String() string {
return ToString(int64(m), 1000, "B", "B")
}
// Floor returns MetricBytes with all but the largest unit zeroed out. So that e.g. 1GB1MB1KB → 1GB.
func (b MetricBytes) Floor() MetricBytes {
switch {
case b > Exabyte:
return (b / Exabyte) * Exabyte
case b > Petabyte:
return (b / Petabyte) * Petabyte
case b > Terabyte:
return (b / Terabyte) * Terabyte
case b > Gigabyte:
return (b / Gigabyte) * Gigabyte
case b > Megabyte:
return (b / Megabyte) * Megabyte
case b > Kilobyte:
return (b / Kilobyte) * Kilobyte
default:
return b
}
}
// Round returns MetricBytes with all but the first n units zeroed out. So that e.g. 1GB1MB1KB → 1GB1MB, if n is 2.
func (b MetricBytes) Round(n int) MetricBytes {
idx := 0
switch {
case b > Exabyte:
idx = n
case b > Petabyte:
idx = n + 1
case b > Terabyte:
idx = n + 2
case b > Gigabyte:
idx = n + 3
case b > Megabyte:
idx = n + 4
case b > Kilobyte:
idx = n + 5
}
switch idx {
case 1:
return b - b%Exabyte
case 2:
return b - b%Petabyte
case 3:
return b - b%Terabyte
case 4:
return b - b%Gigabyte
case 5:
return b - b%Megabyte
case 6:
return b - b%Kilobyte
default:
return b
}
}
// ParseStrictBytes supports both iB and B suffixes for base 2 and metric,
// respectively. That is, KiB represents 1024 and kB, KB represent 1000.
func ParseStrictBytes(s string) (int64, error) {
n, err := ParseUnit(s, bytesUnitMap)
if err != nil {
n, err = ParseUnit(s, metricBytesUnitMap)
}
return int64(n), err
}

View File

@ -1,13 +0,0 @@
// Package units provides helpful unit multipliers and functions for Go.
//
// The goal of this package is to have functionality similar to the time [1] package.
//
//
// [1] http://golang.org/pkg/time/
//
// It allows for code like this:
//
// n, err := ParseBase2Bytes("1KB")
// // n == 1024
// n = units.Mebibyte * 512
package units

View File

@ -1,11 +0,0 @@
{
$schema: "https://docs.renovatebot.com/renovate-schema.json",
extends: [
"config:recommended",
":semanticCommits",
":semanticCommitTypeAll(chore)",
":semanticCommitScope(deps)",
"group:allNonMajor",
"schedule:earlyMondays", // Run once a week.
],
}

View File

@ -1,50 +0,0 @@
package units
// SI units.
type SI int64
// SI unit multiples.
const (
Kilo SI = 1000
Mega = Kilo * 1000
Giga = Mega * 1000
Tera = Giga * 1000
Peta = Tera * 1000
Exa = Peta * 1000
)
func MakeUnitMap(suffix, shortSuffix string, scale int64) map[string]float64 {
res := map[string]float64{
shortSuffix: 1,
// see below for "k" / "K"
"M" + suffix: float64(scale * scale),
"G" + suffix: float64(scale * scale * scale),
"T" + suffix: float64(scale * scale * scale * scale),
"P" + suffix: float64(scale * scale * scale * scale * scale),
"E" + suffix: float64(scale * scale * scale * scale * scale * scale),
}
// Standard SI prefixes use lowercase "k" for kilo = 1000.
// For compatibility, and to be fool-proof, we accept both "k" and "K" in metric mode.
//
// However, official binary prefixes are always capitalized - "KiB" -
// and we specifically never parse "kB" as 1024B because:
//
// (1) people pedantic enough to use lowercase according to SI unlikely to abuse "k" to mean 1024 :-)
//
// (2) Use of capital K for 1024 was an informal tradition predating IEC prefixes:
// "The binary meaning of the kilobyte for 1024 bytes typically uses the symbol KB, with an
// uppercase letter K."
// -- https://en.wikipedia.org/wiki/Kilobyte#Base_2_(1024_bytes)
// "Capitalization of the letter K became the de facto standard for binary notation, although this
// could not be extended to higher powers, and use of the lowercase k did persist.[13][14][15]"
// -- https://en.wikipedia.org/wiki/Binary_prefix#History
// See also the extensive https://en.wikipedia.org/wiki/Timeline_of_binary_prefixes.
if scale == 1024 {
res["K"+suffix] = float64(scale)
} else {
res["k"+suffix] = float64(scale)
res["K"+suffix] = float64(scale)
}
return res
}

View File

@ -1,138 +0,0 @@
package units
import (
"errors"
"fmt"
"strings"
)
var (
siUnits = []string{"", "K", "M", "G", "T", "P", "E"}
)
func ToString(n int64, scale int64, suffix, baseSuffix string) string {
mn := len(siUnits)
out := make([]string, mn)
for i, m := range siUnits {
if n%scale != 0 || i == 0 && n == 0 {
s := suffix
if i == 0 {
s = baseSuffix
}
out[mn-1-i] = fmt.Sprintf("%d%s%s", n%scale, m, s)
}
n /= scale
if n == 0 {
break
}
}
return strings.Join(out, "")
}
// Below code ripped straight from http://golang.org/src/pkg/time/format.go?s=33392:33438#L1123
var errLeadingInt = errors.New("units: bad [0-9]*") // never printed
// leadingInt consumes the leading [0-9]* from s.
func leadingInt(s string) (x int64, rem string, err error) {
i := 0
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if x >= (1<<63-10)/10 {
// overflow
return 0, "", errLeadingInt
}
x = x*10 + int64(c) - '0'
}
return x, s[i:], nil
}
func ParseUnit(s string, unitMap map[string]float64) (int64, error) {
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
orig := s
f := float64(0)
neg := false
// Consume [-+]?
if s != "" {
c := s[0]
if c == '-' || c == '+' {
neg = c == '-'
s = s[1:]
}
}
// Special case: if all that is left is "0", this is zero.
if s == "0" {
return 0, nil
}
if s == "" {
return 0, errors.New("units: invalid " + orig)
}
for s != "" {
g := float64(0) // this element of the sequence
var x int64
var err error
// The next character must be [0-9.]
if !(s[0] == '.' || ('0' <= s[0] && s[0] <= '9')) {
return 0, errors.New("units: invalid " + orig)
}
// Consume [0-9]*
pl := len(s)
x, s, err = leadingInt(s)
if err != nil {
return 0, errors.New("units: invalid " + orig)
}
g = float64(x)
pre := pl != len(s) // whether we consumed anything before a period
// Consume (\.[0-9]*)?
post := false
if s != "" && s[0] == '.' {
s = s[1:]
pl := len(s)
x, s, err = leadingInt(s)
if err != nil {
return 0, errors.New("units: invalid " + orig)
}
scale := 1.0
for n := pl - len(s); n > 0; n-- {
scale *= 10
}
g += float64(x) / scale
post = pl != len(s)
}
if !pre && !post {
// no digits (e.g. ".s" or "-.s")
return 0, errors.New("units: invalid " + orig)
}
// Consume unit.
i := 0
for ; i < len(s); i++ {
c := s[i]
if c == '.' || ('0' <= c && c <= '9') {
break
}
}
u := s[:i]
s = s[i:]
unit, ok := unitMap[u]
if !ok {
return 0, errors.New("units: unknown unit " + u + " in " + orig)
}
f += g * unit
}
if neg {
f = -f
}
if f < float64(-1<<63) || f > float64(1<<63-1) {
return 0, errors.New("units: overflow parsing unit")
}
return int64(f), nil
}

View File

@ -1,26 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
/metrics.out
.idea

View File

@ -1,13 +0,0 @@
language: go
go:
- "1.x"
env:
- GO111MODULE=on
install:
- go get ./...
script:
- go test ./...

View File

@ -1,20 +0,0 @@
The MIT License (MIT)
Copyright (c) 2013 Armon Dadgar
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,91 +0,0 @@
go-metrics
==========
This library provides a `metrics` package which can be used to instrument code,
expose application metrics, and profile runtime performance in a flexible manner.
Current API: [![GoDoc](https://godoc.org/github.com/armon/go-metrics?status.svg)](https://godoc.org/github.com/armon/go-metrics)
Sinks
-----
The `metrics` package makes use of a `MetricSink` interface to support delivery
to any type of backend. Currently the following sinks are provided:
* StatsiteSink : Sinks to a [statsite](https://github.com/armon/statsite/) instance (TCP)
* StatsdSink: Sinks to a [StatsD](https://github.com/etsy/statsd/) / statsite instance (UDP)
* PrometheusSink: Sinks to a [Prometheus](http://prometheus.io/) metrics endpoint (exposed via HTTP for scrapes)
* InmemSink : Provides in-memory aggregation, can be used to export stats
* FanoutSink : Sinks to multiple sinks. Enables writing to multiple statsite instances for example.
* BlackholeSink : Sinks to nowhere
In addition to the sinks, the `InmemSignal` can be used to catch a signal,
and dump a formatted output of recent metrics. For example, when a process gets
a SIGUSR1, it can dump to stderr recent performance metrics for debugging.
Labels
------
Most metrics do have an equivalent ending with `WithLabels`, such methods
allow to push metrics with labels and use some features of underlying Sinks
(ex: translated into Prometheus labels).
Since some of these labels may increase greatly cardinality of metrics, the
library allow to filter labels using a blacklist/whitelist filtering system
which is global to all metrics.
* If `Config.AllowedLabels` is not nil, then only labels specified in this value will be sent to underlying Sink, otherwise, all labels are sent by default.
* If `Config.BlockedLabels` is not nil, any label specified in this value will not be sent to underlying Sinks.
By default, both `Config.AllowedLabels` and `Config.BlockedLabels` are nil, meaning that
no tags are filetered at all, but it allow to a user to globally block some tags with high
cardinality at application level.
Examples
--------
Here is an example of using the package:
```go
func SlowMethod() {
// Profiling the runtime of a method
defer metrics.MeasureSince([]string{"SlowMethod"}, time.Now())
}
// Configure a statsite sink as the global metrics sink
sink, _ := metrics.NewStatsiteSink("statsite:8125")
metrics.NewGlobal(metrics.DefaultConfig("service-name"), sink)
// Emit a Key/Value pair
metrics.EmitKey([]string{"questions", "meaning of life"}, 42)
```
Here is an example of setting up a signal handler:
```go
// Setup the inmem sink and signal handler
inm := metrics.NewInmemSink(10*time.Second, time.Minute)
sig := metrics.DefaultInmemSignal(inm)
metrics.NewGlobal(metrics.DefaultConfig("service-name"), inm)
// Run some code
inm.SetGauge([]string{"foo"}, 42)
inm.EmitKey([]string{"bar"}, 30)
inm.IncrCounter([]string{"baz"}, 42)
inm.IncrCounter([]string{"baz"}, 1)
inm.IncrCounter([]string{"baz"}, 80)
inm.AddSample([]string{"method", "wow"}, 42)
inm.AddSample([]string{"method", "wow"}, 100)
inm.AddSample([]string{"method", "wow"}, 22)
....
```
When a signal comes in, output like the following will be dumped to stderr:
[2014-01-28 14:57:33.04 -0800 PST][G] 'foo': 42.000
[2014-01-28 14:57:33.04 -0800 PST][P] 'bar': 30.000
[2014-01-28 14:57:33.04 -0800 PST][C] 'baz': Count: 3 Min: 1.000 Mean: 41.000 Max: 80.000 Stddev: 39.509
[2014-01-28 14:57:33.04 -0800 PST][S] 'method.wow': Count: 3 Min: 22.000 Mean: 54.667 Max: 100.000 Stddev: 40.513

View File

@ -1,12 +0,0 @@
// +build !windows
package metrics
import (
"syscall"
)
const (
// DefaultSignal is used with DefaultInmemSignal
DefaultSignal = syscall.SIGUSR1
)

View File

@ -1,13 +0,0 @@
// +build windows
package metrics
import (
"syscall"
)
const (
// DefaultSignal is used with DefaultInmemSignal
// Windows has no SIGUSR1, use SIGBREAK
DefaultSignal = syscall.Signal(21)
)

View File

@ -1,339 +0,0 @@
package metrics
import (
"bytes"
"fmt"
"math"
"net/url"
"strings"
"sync"
"time"
)
var spaceReplacer = strings.NewReplacer(" ", "_")
// InmemSink provides a MetricSink that does in-memory aggregation
// without sending metrics over a network. It can be embedded within
// an application to provide profiling information.
type InmemSink struct {
// How long is each aggregation interval
interval time.Duration
// Retain controls how many metrics interval we keep
retain time.Duration
// maxIntervals is the maximum length of intervals.
// It is retain / interval.
maxIntervals int
// intervals is a slice of the retained intervals
intervals []*IntervalMetrics
intervalLock sync.RWMutex
rateDenom float64
}
// IntervalMetrics stores the aggregated metrics
// for a specific interval
type IntervalMetrics struct {
sync.RWMutex
// The start time of the interval
Interval time.Time
// Gauges maps the key to the last set value
Gauges map[string]GaugeValue
// Points maps the string to the list of emitted values
// from EmitKey
Points map[string][]float32
// Counters maps the string key to a sum of the counter
// values
Counters map[string]SampledValue
// Samples maps the key to an AggregateSample,
// which has the rolled up view of a sample
Samples map[string]SampledValue
// done is closed when this interval has ended, and a new IntervalMetrics
// has been created to receive any future metrics.
done chan struct{}
}
// NewIntervalMetrics creates a new IntervalMetrics for a given interval
func NewIntervalMetrics(intv time.Time) *IntervalMetrics {
return &IntervalMetrics{
Interval: intv,
Gauges: make(map[string]GaugeValue),
Points: make(map[string][]float32),
Counters: make(map[string]SampledValue),
Samples: make(map[string]SampledValue),
done: make(chan struct{}),
}
}
// AggregateSample is used to hold aggregate metrics
// about a sample
type AggregateSample struct {
Count int // The count of emitted pairs
Rate float64 // The values rate per time unit (usually 1 second)
Sum float64 // The sum of values
SumSq float64 `json:"-"` // The sum of squared values
Min float64 // Minimum value
Max float64 // Maximum value
LastUpdated time.Time `json:"-"` // When value was last updated
}
// Computes a Stddev of the values
func (a *AggregateSample) Stddev() float64 {
num := (float64(a.Count) * a.SumSq) - math.Pow(a.Sum, 2)
div := float64(a.Count * (a.Count - 1))
if div == 0 {
return 0
}
return math.Sqrt(num / div)
}
// Computes a mean of the values
func (a *AggregateSample) Mean() float64 {
if a.Count == 0 {
return 0
}
return a.Sum / float64(a.Count)
}
// Ingest is used to update a sample
func (a *AggregateSample) Ingest(v float64, rateDenom float64) {
a.Count++
a.Sum += v
a.SumSq += (v * v)
if v < a.Min || a.Count == 1 {
a.Min = v
}
if v > a.Max || a.Count == 1 {
a.Max = v
}
a.Rate = float64(a.Sum) / rateDenom
a.LastUpdated = time.Now()
}
func (a *AggregateSample) String() string {
if a.Count == 0 {
return "Count: 0"
} else if a.Stddev() == 0 {
return fmt.Sprintf("Count: %d Sum: %0.3f LastUpdated: %s", a.Count, a.Sum, a.LastUpdated)
} else {
return fmt.Sprintf("Count: %d Min: %0.3f Mean: %0.3f Max: %0.3f Stddev: %0.3f Sum: %0.3f LastUpdated: %s",
a.Count, a.Min, a.Mean(), a.Max, a.Stddev(), a.Sum, a.LastUpdated)
}
}
// NewInmemSinkFromURL creates an InmemSink from a URL. It is used
// (and tested) from NewMetricSinkFromURL.
func NewInmemSinkFromURL(u *url.URL) (MetricSink, error) {
params := u.Query()
interval, err := time.ParseDuration(params.Get("interval"))
if err != nil {
return nil, fmt.Errorf("Bad 'interval' param: %s", err)
}
retain, err := time.ParseDuration(params.Get("retain"))
if err != nil {
return nil, fmt.Errorf("Bad 'retain' param: %s", err)
}
return NewInmemSink(interval, retain), nil
}
// NewInmemSink is used to construct a new in-memory sink.
// Uses an aggregation interval and maximum retention period.
func NewInmemSink(interval, retain time.Duration) *InmemSink {
rateTimeUnit := time.Second
i := &InmemSink{
interval: interval,
retain: retain,
maxIntervals: int(retain / interval),
rateDenom: float64(interval.Nanoseconds()) / float64(rateTimeUnit.Nanoseconds()),
}
i.intervals = make([]*IntervalMetrics, 0, i.maxIntervals)
return i
}
func (i *InmemSink) SetGauge(key []string, val float32) {
i.SetGaugeWithLabels(key, val, nil)
}
func (i *InmemSink) SetGaugeWithLabels(key []string, val float32, labels []Label) {
k, name := i.flattenKeyLabels(key, labels)
intv := i.getInterval()
intv.Lock()
defer intv.Unlock()
intv.Gauges[k] = GaugeValue{Name: name, Value: val, Labels: labels}
}
func (i *InmemSink) EmitKey(key []string, val float32) {
k := i.flattenKey(key)
intv := i.getInterval()
intv.Lock()
defer intv.Unlock()
vals := intv.Points[k]
intv.Points[k] = append(vals, val)
}
func (i *InmemSink) IncrCounter(key []string, val float32) {
i.IncrCounterWithLabels(key, val, nil)
}
func (i *InmemSink) IncrCounterWithLabels(key []string, val float32, labels []Label) {
k, name := i.flattenKeyLabels(key, labels)
intv := i.getInterval()
intv.Lock()
defer intv.Unlock()
agg, ok := intv.Counters[k]
if !ok {
agg = SampledValue{
Name: name,
AggregateSample: &AggregateSample{},
Labels: labels,
}
intv.Counters[k] = agg
}
agg.Ingest(float64(val), i.rateDenom)
}
func (i *InmemSink) AddSample(key []string, val float32) {
i.AddSampleWithLabels(key, val, nil)
}
func (i *InmemSink) AddSampleWithLabels(key []string, val float32, labels []Label) {
k, name := i.flattenKeyLabels(key, labels)
intv := i.getInterval()
intv.Lock()
defer intv.Unlock()
agg, ok := intv.Samples[k]
if !ok {
agg = SampledValue{
Name: name,
AggregateSample: &AggregateSample{},
Labels: labels,
}
intv.Samples[k] = agg
}
agg.Ingest(float64(val), i.rateDenom)
}
// Data is used to retrieve all the aggregated metrics
// Intervals may be in use, and a read lock should be acquired
func (i *InmemSink) Data() []*IntervalMetrics {
// Get the current interval, forces creation
i.getInterval()
i.intervalLock.RLock()
defer i.intervalLock.RUnlock()
n := len(i.intervals)
intervals := make([]*IntervalMetrics, n)
copy(intervals[:n-1], i.intervals[:n-1])
current := i.intervals[n-1]
// make its own copy for current interval
intervals[n-1] = &IntervalMetrics{}
copyCurrent := intervals[n-1]
current.RLock()
*copyCurrent = *current
// RWMutex is not safe to copy, so create a new instance on the copy
copyCurrent.RWMutex = sync.RWMutex{}
copyCurrent.Gauges = make(map[string]GaugeValue, len(current.Gauges))
for k, v := range current.Gauges {
copyCurrent.Gauges[k] = v
}
// saved values will be not change, just copy its link
copyCurrent.Points = make(map[string][]float32, len(current.Points))
for k, v := range current.Points {
copyCurrent.Points[k] = v
}
copyCurrent.Counters = make(map[string]SampledValue, len(current.Counters))
for k, v := range current.Counters {
copyCurrent.Counters[k] = v.deepCopy()
}
copyCurrent.Samples = make(map[string]SampledValue, len(current.Samples))
for k, v := range current.Samples {
copyCurrent.Samples[k] = v.deepCopy()
}
current.RUnlock()
return intervals
}
// getInterval returns the current interval. A new interval is created if no
// previous interval exists, or if the current time is beyond the window for the
// current interval.
func (i *InmemSink) getInterval() *IntervalMetrics {
intv := time.Now().Truncate(i.interval)
// Attempt to return the existing interval first, because it only requires
// a read lock.
i.intervalLock.RLock()
n := len(i.intervals)
if n > 0 && i.intervals[n-1].Interval == intv {
defer i.intervalLock.RUnlock()
return i.intervals[n-1]
}
i.intervalLock.RUnlock()
i.intervalLock.Lock()
defer i.intervalLock.Unlock()
// Re-check for an existing interval now that the lock is re-acquired.
n = len(i.intervals)
if n > 0 && i.intervals[n-1].Interval == intv {
return i.intervals[n-1]
}
current := NewIntervalMetrics(intv)
i.intervals = append(i.intervals, current)
if n > 0 {
close(i.intervals[n-1].done)
}
n++
// Prune old intervals if the count exceeds the max.
if n >= i.maxIntervals {
copy(i.intervals[0:], i.intervals[n-i.maxIntervals:])
i.intervals = i.intervals[:i.maxIntervals]
}
return current
}
// Flattens the key for formatting, removes spaces
func (i *InmemSink) flattenKey(parts []string) string {
buf := &bytes.Buffer{}
joined := strings.Join(parts, ".")
spaceReplacer.WriteString(buf, joined)
return buf.String()
}
// Flattens the key for formatting along with its labels, removes spaces
func (i *InmemSink) flattenKeyLabels(parts []string, labels []Label) (string, string) {
key := i.flattenKey(parts)
buf := bytes.NewBufferString(key)
for _, label := range labels {
spaceReplacer.WriteString(buf, fmt.Sprintf(";%s=%s", label.Name, label.Value))
}
return buf.String(), key
}

View File

@ -1,162 +0,0 @@
package metrics
import (
"context"
"fmt"
"net/http"
"sort"
"time"
)
// MetricsSummary holds a roll-up of metrics info for a given interval
type MetricsSummary struct {
Timestamp string
Gauges []GaugeValue
Points []PointValue
Counters []SampledValue
Samples []SampledValue
}
type GaugeValue struct {
Name string
Hash string `json:"-"`
Value float32
Labels []Label `json:"-"`
DisplayLabels map[string]string `json:"Labels"`
}
type PointValue struct {
Name string
Points []float32
}
type SampledValue struct {
Name string
Hash string `json:"-"`
*AggregateSample
Mean float64
Stddev float64
Labels []Label `json:"-"`
DisplayLabels map[string]string `json:"Labels"`
}
// deepCopy allocates a new instance of AggregateSample
func (source *SampledValue) deepCopy() SampledValue {
dest := *source
if source.AggregateSample != nil {
dest.AggregateSample = &AggregateSample{}
*dest.AggregateSample = *source.AggregateSample
}
return dest
}
// DisplayMetrics returns a summary of the metrics from the most recent finished interval.
func (i *InmemSink) DisplayMetrics(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
data := i.Data()
var interval *IntervalMetrics
n := len(data)
switch {
case n == 0:
return nil, fmt.Errorf("no metric intervals have been initialized yet")
case n == 1:
// Show the current interval if it's all we have
interval = data[0]
default:
// Show the most recent finished interval if we have one
interval = data[n-2]
}
return newMetricSummaryFromInterval(interval), nil
}
func newMetricSummaryFromInterval(interval *IntervalMetrics) MetricsSummary {
interval.RLock()
defer interval.RUnlock()
summary := MetricsSummary{
Timestamp: interval.Interval.Round(time.Second).UTC().String(),
Gauges: make([]GaugeValue, 0, len(interval.Gauges)),
Points: make([]PointValue, 0, len(interval.Points)),
}
// Format and sort the output of each metric type, so it gets displayed in a
// deterministic order.
for name, points := range interval.Points {
summary.Points = append(summary.Points, PointValue{name, points})
}
sort.Slice(summary.Points, func(i, j int) bool {
return summary.Points[i].Name < summary.Points[j].Name
})
for hash, value := range interval.Gauges {
value.Hash = hash
value.DisplayLabels = make(map[string]string)
for _, label := range value.Labels {
value.DisplayLabels[label.Name] = label.Value
}
value.Labels = nil
summary.Gauges = append(summary.Gauges, value)
}
sort.Slice(summary.Gauges, func(i, j int) bool {
return summary.Gauges[i].Hash < summary.Gauges[j].Hash
})
summary.Counters = formatSamples(interval.Counters)
summary.Samples = formatSamples(interval.Samples)
return summary
}
func formatSamples(source map[string]SampledValue) []SampledValue {
output := make([]SampledValue, 0, len(source))
for hash, sample := range source {
displayLabels := make(map[string]string)
for _, label := range sample.Labels {
displayLabels[label.Name] = label.Value
}
output = append(output, SampledValue{
Name: sample.Name,
Hash: hash,
AggregateSample: sample.AggregateSample,
Mean: sample.AggregateSample.Mean(),
Stddev: sample.AggregateSample.Stddev(),
DisplayLabels: displayLabels,
})
}
sort.Slice(output, func(i, j int) bool {
return output[i].Hash < output[j].Hash
})
return output
}
type Encoder interface {
Encode(interface{}) error
}
// Stream writes metrics using encoder.Encode each time an interval ends. Runs
// until the request context is cancelled, or the encoder returns an error.
// The caller is responsible for logging any errors from encoder.
func (i *InmemSink) Stream(ctx context.Context, encoder Encoder) {
interval := i.getInterval()
for {
select {
case <-interval.done:
summary := newMetricSummaryFromInterval(interval)
if err := encoder.Encode(summary); err != nil {
return
}
// update interval to the next one
interval = i.getInterval()
case <-ctx.Done():
return
}
}
}

View File

@ -1,117 +0,0 @@
package metrics
import (
"bytes"
"fmt"
"io"
"os"
"os/signal"
"strings"
"sync"
"syscall"
)
// InmemSignal is used to listen for a given signal, and when received,
// to dump the current metrics from the InmemSink to an io.Writer
type InmemSignal struct {
signal syscall.Signal
inm *InmemSink
w io.Writer
sigCh chan os.Signal
stop bool
stopCh chan struct{}
stopLock sync.Mutex
}
// NewInmemSignal creates a new InmemSignal which listens for a given signal,
// and dumps the current metrics out to a writer
func NewInmemSignal(inmem *InmemSink, sig syscall.Signal, w io.Writer) *InmemSignal {
i := &InmemSignal{
signal: sig,
inm: inmem,
w: w,
sigCh: make(chan os.Signal, 1),
stopCh: make(chan struct{}),
}
signal.Notify(i.sigCh, sig)
go i.run()
return i
}
// DefaultInmemSignal returns a new InmemSignal that responds to SIGUSR1
// and writes output to stderr. Windows uses SIGBREAK
func DefaultInmemSignal(inmem *InmemSink) *InmemSignal {
return NewInmemSignal(inmem, DefaultSignal, os.Stderr)
}
// Stop is used to stop the InmemSignal from listening
func (i *InmemSignal) Stop() {
i.stopLock.Lock()
defer i.stopLock.Unlock()
if i.stop {
return
}
i.stop = true
close(i.stopCh)
signal.Stop(i.sigCh)
}
// run is a long running routine that handles signals
func (i *InmemSignal) run() {
for {
select {
case <-i.sigCh:
i.dumpStats()
case <-i.stopCh:
return
}
}
}
// dumpStats is used to dump the data to output writer
func (i *InmemSignal) dumpStats() {
buf := bytes.NewBuffer(nil)
data := i.inm.Data()
// Skip the last period which is still being aggregated
for j := 0; j < len(data)-1; j++ {
intv := data[j]
intv.RLock()
for _, val := range intv.Gauges {
name := i.flattenLabels(val.Name, val.Labels)
fmt.Fprintf(buf, "[%v][G] '%s': %0.3f\n", intv.Interval, name, val.Value)
}
for name, vals := range intv.Points {
for _, val := range vals {
fmt.Fprintf(buf, "[%v][P] '%s': %0.3f\n", intv.Interval, name, val)
}
}
for _, agg := range intv.Counters {
name := i.flattenLabels(agg.Name, agg.Labels)
fmt.Fprintf(buf, "[%v][C] '%s': %s\n", intv.Interval, name, agg.AggregateSample)
}
for _, agg := range intv.Samples {
name := i.flattenLabels(agg.Name, agg.Labels)
fmt.Fprintf(buf, "[%v][S] '%s': %s\n", intv.Interval, name, agg.AggregateSample)
}
intv.RUnlock()
}
// Write out the bytes
i.w.Write(buf.Bytes())
}
// Flattens the key for formatting along with its labels, removes spaces
func (i *InmemSignal) flattenLabels(name string, labels []Label) string {
buf := bytes.NewBufferString(name)
replacer := strings.NewReplacer(" ", "_", ":", "_")
for _, label := range labels {
replacer.WriteString(buf, ".")
replacer.WriteString(buf, label.Value)
}
return buf.String()
}

View File

@ -1,299 +0,0 @@
package metrics
import (
"runtime"
"strings"
"time"
iradix "github.com/hashicorp/go-immutable-radix"
)
type Label struct {
Name string
Value string
}
func (m *Metrics) SetGauge(key []string, val float32) {
m.SetGaugeWithLabels(key, val, nil)
}
func (m *Metrics) SetGaugeWithLabels(key []string, val float32, labels []Label) {
if m.HostName != "" {
if m.EnableHostnameLabel {
labels = append(labels, Label{"host", m.HostName})
} else if m.EnableHostname {
key = insert(0, m.HostName, key)
}
}
if m.EnableTypePrefix {
key = insert(0, "gauge", key)
}
if m.ServiceName != "" {
if m.EnableServiceLabel {
labels = append(labels, Label{"service", m.ServiceName})
} else {
key = insert(0, m.ServiceName, key)
}
}
allowed, labelsFiltered := m.allowMetric(key, labels)
if !allowed {
return
}
m.sink.SetGaugeWithLabels(key, val, labelsFiltered)
}
func (m *Metrics) EmitKey(key []string, val float32) {
if m.EnableTypePrefix {
key = insert(0, "kv", key)
}
if m.ServiceName != "" {
key = insert(0, m.ServiceName, key)
}
allowed, _ := m.allowMetric(key, nil)
if !allowed {
return
}
m.sink.EmitKey(key, val)
}
func (m *Metrics) IncrCounter(key []string, val float32) {
m.IncrCounterWithLabels(key, val, nil)
}
func (m *Metrics) IncrCounterWithLabels(key []string, val float32, labels []Label) {
if m.HostName != "" && m.EnableHostnameLabel {
labels = append(labels, Label{"host", m.HostName})
}
if m.EnableTypePrefix {
key = insert(0, "counter", key)
}
if m.ServiceName != "" {
if m.EnableServiceLabel {
labels = append(labels, Label{"service", m.ServiceName})
} else {
key = insert(0, m.ServiceName, key)
}
}
allowed, labelsFiltered := m.allowMetric(key, labels)
if !allowed {
return
}
m.sink.IncrCounterWithLabels(key, val, labelsFiltered)
}
func (m *Metrics) AddSample(key []string, val float32) {
m.AddSampleWithLabels(key, val, nil)
}
func (m *Metrics) AddSampleWithLabels(key []string, val float32, labels []Label) {
if m.HostName != "" && m.EnableHostnameLabel {
labels = append(labels, Label{"host", m.HostName})
}
if m.EnableTypePrefix {
key = insert(0, "sample", key)
}
if m.ServiceName != "" {
if m.EnableServiceLabel {
labels = append(labels, Label{"service", m.ServiceName})
} else {
key = insert(0, m.ServiceName, key)
}
}
allowed, labelsFiltered := m.allowMetric(key, labels)
if !allowed {
return
}
m.sink.AddSampleWithLabels(key, val, labelsFiltered)
}
func (m *Metrics) MeasureSince(key []string, start time.Time) {
m.MeasureSinceWithLabels(key, start, nil)
}
func (m *Metrics) MeasureSinceWithLabels(key []string, start time.Time, labels []Label) {
if m.HostName != "" && m.EnableHostnameLabel {
labels = append(labels, Label{"host", m.HostName})
}
if m.EnableTypePrefix {
key = insert(0, "timer", key)
}
if m.ServiceName != "" {
if m.EnableServiceLabel {
labels = append(labels, Label{"service", m.ServiceName})
} else {
key = insert(0, m.ServiceName, key)
}
}
allowed, labelsFiltered := m.allowMetric(key, labels)
if !allowed {
return
}
now := time.Now()
elapsed := now.Sub(start)
msec := float32(elapsed.Nanoseconds()) / float32(m.TimerGranularity)
m.sink.AddSampleWithLabels(key, msec, labelsFiltered)
}
// UpdateFilter overwrites the existing filter with the given rules.
func (m *Metrics) UpdateFilter(allow, block []string) {
m.UpdateFilterAndLabels(allow, block, m.AllowedLabels, m.BlockedLabels)
}
// UpdateFilterAndLabels overwrites the existing filter with the given rules.
func (m *Metrics) UpdateFilterAndLabels(allow, block, allowedLabels, blockedLabels []string) {
m.filterLock.Lock()
defer m.filterLock.Unlock()
m.AllowedPrefixes = allow
m.BlockedPrefixes = block
if allowedLabels == nil {
// Having a white list means we take only elements from it
m.allowedLabels = nil
} else {
m.allowedLabels = make(map[string]bool)
for _, v := range allowedLabels {
m.allowedLabels[v] = true
}
}
m.blockedLabels = make(map[string]bool)
for _, v := range blockedLabels {
m.blockedLabels[v] = true
}
m.AllowedLabels = allowedLabels
m.BlockedLabels = blockedLabels
m.filter = iradix.New()
for _, prefix := range m.AllowedPrefixes {
m.filter, _, _ = m.filter.Insert([]byte(prefix), true)
}
for _, prefix := range m.BlockedPrefixes {
m.filter, _, _ = m.filter.Insert([]byte(prefix), false)
}
}
func (m *Metrics) Shutdown() {
if ss, ok := m.sink.(ShutdownSink); ok {
ss.Shutdown()
}
}
// labelIsAllowed return true if a should be included in metric
// the caller should lock m.filterLock while calling this method
func (m *Metrics) labelIsAllowed(label *Label) bool {
labelName := (*label).Name
if m.blockedLabels != nil {
_, ok := m.blockedLabels[labelName]
if ok {
// If present, let's remove this label
return false
}
}
if m.allowedLabels != nil {
_, ok := m.allowedLabels[labelName]
return ok
}
// Allow by default
return true
}
// filterLabels return only allowed labels
// the caller should lock m.filterLock while calling this method
func (m *Metrics) filterLabels(labels []Label) []Label {
if labels == nil {
return nil
}
toReturn := []Label{}
for _, label := range labels {
if m.labelIsAllowed(&label) {
toReturn = append(toReturn, label)
}
}
return toReturn
}
// Returns whether the metric should be allowed based on configured prefix filters
// Also return the applicable labels
func (m *Metrics) allowMetric(key []string, labels []Label) (bool, []Label) {
m.filterLock.RLock()
defer m.filterLock.RUnlock()
if m.filter == nil || m.filter.Len() == 0 {
return m.Config.FilterDefault, m.filterLabels(labels)
}
_, allowed, ok := m.filter.Root().LongestPrefix([]byte(strings.Join(key, ".")))
if !ok {
return m.Config.FilterDefault, m.filterLabels(labels)
}
return allowed.(bool), m.filterLabels(labels)
}
// Periodically collects runtime stats to publish
func (m *Metrics) collectStats() {
for {
time.Sleep(m.ProfileInterval)
m.EmitRuntimeStats()
}
}
// Emits various runtime statsitics
func (m *Metrics) EmitRuntimeStats() {
// Export number of Goroutines
numRoutines := runtime.NumGoroutine()
m.SetGauge([]string{"runtime", "num_goroutines"}, float32(numRoutines))
// Export memory stats
var stats runtime.MemStats
runtime.ReadMemStats(&stats)
m.SetGauge([]string{"runtime", "alloc_bytes"}, float32(stats.Alloc))
m.SetGauge([]string{"runtime", "sys_bytes"}, float32(stats.Sys))
m.SetGauge([]string{"runtime", "malloc_count"}, float32(stats.Mallocs))
m.SetGauge([]string{"runtime", "free_count"}, float32(stats.Frees))
m.SetGauge([]string{"runtime", "heap_objects"}, float32(stats.HeapObjects))
m.SetGauge([]string{"runtime", "total_gc_pause_ns"}, float32(stats.PauseTotalNs))
m.SetGauge([]string{"runtime", "total_gc_runs"}, float32(stats.NumGC))
// Export info about the last few GC runs
num := stats.NumGC
// Handle wrap around
if num < m.lastNumGC {
m.lastNumGC = 0
}
// Ensure we don't scan more than 256
if num-m.lastNumGC >= 256 {
m.lastNumGC = num - 255
}
for i := m.lastNumGC; i < num; i++ {
pause := stats.PauseNs[i%256]
m.AddSample([]string{"runtime", "gc_pause_ns"}, float32(pause))
}
m.lastNumGC = num
}
// Creates a new slice with the provided string value as the first element
// and the provided slice values as the remaining values.
// Ordering of the values in the provided input slice is kept in tact in the output slice.
func insert(i int, v string, s []string) []string {
// Allocate new slice to avoid modifying the input slice
newS := make([]string, len(s)+1)
// Copy s[0, i-1] into newS
for j := 0; j < i; j++ {
newS[j] = s[j]
}
// Insert provided element at index i
newS[i] = v
// Copy s[i, len(s)-1] into newS starting at newS[i+1]
for j := i; j < len(s); j++ {
newS[j+1] = s[j]
}
return newS
}

View File

@ -1,132 +0,0 @@
package metrics
import (
"fmt"
"net/url"
)
// The MetricSink interface is used to transmit metrics information
// to an external system
type MetricSink interface {
// A Gauge should retain the last value it is set to
SetGauge(key []string, val float32)
SetGaugeWithLabels(key []string, val float32, labels []Label)
// Should emit a Key/Value pair for each call
EmitKey(key []string, val float32)
// Counters should accumulate values
IncrCounter(key []string, val float32)
IncrCounterWithLabels(key []string, val float32, labels []Label)
// Samples are for timing information, where quantiles are used
AddSample(key []string, val float32)
AddSampleWithLabels(key []string, val float32, labels []Label)
}
type ShutdownSink interface {
MetricSink
// Shutdown the metric sink, flush metrics to storage, and cleanup resources.
// Called immediately prior to application exit. Implementations must block
// until metrics are flushed to storage.
Shutdown()
}
// BlackholeSink is used to just blackhole messages
type BlackholeSink struct{}
func (*BlackholeSink) SetGauge(key []string, val float32) {}
func (*BlackholeSink) SetGaugeWithLabels(key []string, val float32, labels []Label) {}
func (*BlackholeSink) EmitKey(key []string, val float32) {}
func (*BlackholeSink) IncrCounter(key []string, val float32) {}
func (*BlackholeSink) IncrCounterWithLabels(key []string, val float32, labels []Label) {}
func (*BlackholeSink) AddSample(key []string, val float32) {}
func (*BlackholeSink) AddSampleWithLabels(key []string, val float32, labels []Label) {}
// FanoutSink is used to sink to fanout values to multiple sinks
type FanoutSink []MetricSink
func (fh FanoutSink) SetGauge(key []string, val float32) {
fh.SetGaugeWithLabels(key, val, nil)
}
func (fh FanoutSink) SetGaugeWithLabels(key []string, val float32, labels []Label) {
for _, s := range fh {
s.SetGaugeWithLabels(key, val, labels)
}
}
func (fh FanoutSink) EmitKey(key []string, val float32) {
for _, s := range fh {
s.EmitKey(key, val)
}
}
func (fh FanoutSink) IncrCounter(key []string, val float32) {
fh.IncrCounterWithLabels(key, val, nil)
}
func (fh FanoutSink) IncrCounterWithLabels(key []string, val float32, labels []Label) {
for _, s := range fh {
s.IncrCounterWithLabels(key, val, labels)
}
}
func (fh FanoutSink) AddSample(key []string, val float32) {
fh.AddSampleWithLabels(key, val, nil)
}
func (fh FanoutSink) AddSampleWithLabels(key []string, val float32, labels []Label) {
for _, s := range fh {
s.AddSampleWithLabels(key, val, labels)
}
}
func (fh FanoutSink) Shutdown() {
for _, s := range fh {
if ss, ok := s.(ShutdownSink); ok {
ss.Shutdown()
}
}
}
// sinkURLFactoryFunc is an generic interface around the *SinkFromURL() function provided
// by each sink type
type sinkURLFactoryFunc func(*url.URL) (MetricSink, error)
// sinkRegistry supports the generic NewMetricSink function by mapping URL
// schemes to metric sink factory functions
var sinkRegistry = map[string]sinkURLFactoryFunc{
"statsd": NewStatsdSinkFromURL,
"statsite": NewStatsiteSinkFromURL,
"inmem": NewInmemSinkFromURL,
}
// NewMetricSinkFromURL allows a generic URL input to configure any of the
// supported sinks. The scheme of the URL identifies the type of the sink, the
// and query parameters are used to set options.
//
// "statsd://" - Initializes a StatsdSink. The host and port are passed through
// as the "addr" of the sink
//
// "statsite://" - Initializes a StatsiteSink. The host and port become the
// "addr" of the sink
//
// "inmem://" - Initializes an InmemSink. The host and port are ignored. The
// "interval" and "duration" query parameters must be specified with valid
// durations, see NewInmemSink for details.
func NewMetricSinkFromURL(urlStr string) (MetricSink, error) {
u, err := url.Parse(urlStr)
if err != nil {
return nil, err
}
sinkURLFactoryFunc := sinkRegistry[u.Scheme]
if sinkURLFactoryFunc == nil {
return nil, fmt.Errorf(
"cannot create metric sink, unrecognized sink name: %q", u.Scheme)
}
return sinkURLFactoryFunc(u)
}

View File

@ -1,158 +0,0 @@
package metrics
import (
"os"
"sync"
"sync/atomic"
"time"
iradix "github.com/hashicorp/go-immutable-radix"
)
// Config is used to configure metrics settings
type Config struct {
ServiceName string // Prefixed with keys to separate services
HostName string // Hostname to use. If not provided and EnableHostname, it will be os.Hostname
EnableHostname bool // Enable prefixing gauge values with hostname
EnableHostnameLabel bool // Enable adding hostname to labels
EnableServiceLabel bool // Enable adding service to labels
EnableRuntimeMetrics bool // Enables profiling of runtime metrics (GC, Goroutines, Memory)
EnableTypePrefix bool // Prefixes key with a type ("counter", "gauge", "timer")
TimerGranularity time.Duration // Granularity of timers.
ProfileInterval time.Duration // Interval to profile runtime metrics
AllowedPrefixes []string // A list of metric prefixes to allow, with '.' as the separator
BlockedPrefixes []string // A list of metric prefixes to block, with '.' as the separator
AllowedLabels []string // A list of metric labels to allow, with '.' as the separator
BlockedLabels []string // A list of metric labels to block, with '.' as the separator
FilterDefault bool // Whether to allow metrics by default
}
// Metrics represents an instance of a metrics sink that can
// be used to emit
type Metrics struct {
Config
lastNumGC uint32
sink MetricSink
filter *iradix.Tree
allowedLabels map[string]bool
blockedLabels map[string]bool
filterLock sync.RWMutex // Lock filters and allowedLabels/blockedLabels access
}
// Shared global metrics instance
var globalMetrics atomic.Value // *Metrics
func init() {
// Initialize to a blackhole sink to avoid errors
globalMetrics.Store(&Metrics{sink: &BlackholeSink{}})
}
// Default returns the shared global metrics instance.
func Default() *Metrics {
return globalMetrics.Load().(*Metrics)
}
// DefaultConfig provides a sane default configuration
func DefaultConfig(serviceName string) *Config {
c := &Config{
ServiceName: serviceName, // Use client provided service
HostName: "",
EnableHostname: true, // Enable hostname prefix
EnableRuntimeMetrics: true, // Enable runtime profiling
EnableTypePrefix: false, // Disable type prefix
TimerGranularity: time.Millisecond, // Timers are in milliseconds
ProfileInterval: time.Second, // Poll runtime every second
FilterDefault: true, // Don't filter metrics by default
}
// Try to get the hostname
name, _ := os.Hostname()
c.HostName = name
return c
}
// New is used to create a new instance of Metrics
func New(conf *Config, sink MetricSink) (*Metrics, error) {
met := &Metrics{}
met.Config = *conf
met.sink = sink
met.UpdateFilterAndLabels(conf.AllowedPrefixes, conf.BlockedPrefixes, conf.AllowedLabels, conf.BlockedLabels)
// Start the runtime collector
if conf.EnableRuntimeMetrics {
go met.collectStats()
}
return met, nil
}
// NewGlobal is the same as New, but it assigns the metrics object to be
// used globally as well as returning it.
func NewGlobal(conf *Config, sink MetricSink) (*Metrics, error) {
metrics, err := New(conf, sink)
if err == nil {
globalMetrics.Store(metrics)
}
return metrics, err
}
// Proxy all the methods to the globalMetrics instance
func SetGauge(key []string, val float32) {
globalMetrics.Load().(*Metrics).SetGauge(key, val)
}
func SetGaugeWithLabels(key []string, val float32, labels []Label) {
globalMetrics.Load().(*Metrics).SetGaugeWithLabels(key, val, labels)
}
func EmitKey(key []string, val float32) {
globalMetrics.Load().(*Metrics).EmitKey(key, val)
}
func IncrCounter(key []string, val float32) {
globalMetrics.Load().(*Metrics).IncrCounter(key, val)
}
func IncrCounterWithLabels(key []string, val float32, labels []Label) {
globalMetrics.Load().(*Metrics).IncrCounterWithLabels(key, val, labels)
}
func AddSample(key []string, val float32) {
globalMetrics.Load().(*Metrics).AddSample(key, val)
}
func AddSampleWithLabels(key []string, val float32, labels []Label) {
globalMetrics.Load().(*Metrics).AddSampleWithLabels(key, val, labels)
}
func MeasureSince(key []string, start time.Time) {
globalMetrics.Load().(*Metrics).MeasureSince(key, start)
}
func MeasureSinceWithLabels(key []string, start time.Time, labels []Label) {
globalMetrics.Load().(*Metrics).MeasureSinceWithLabels(key, start, labels)
}
func UpdateFilter(allow, block []string) {
globalMetrics.Load().(*Metrics).UpdateFilter(allow, block)
}
// UpdateFilterAndLabels set allow/block prefixes of metrics while allowedLabels
// and blockedLabels - when not nil - allow filtering of labels in order to
// block/allow globally labels (especially useful when having large number of
// values for a given label). See README.md for more information about usage.
func UpdateFilterAndLabels(allow, block, allowedLabels, blockedLabels []string) {
globalMetrics.Load().(*Metrics).UpdateFilterAndLabels(allow, block, allowedLabels, blockedLabels)
}
// Shutdown disables metric collection, then blocks while attempting to flush metrics to storage.
// WARNING: Not all MetricSink backends support this functionality, and calling this will cause them to leak resources.
// This is intended for use immediately prior to application exit.
func Shutdown() {
m := globalMetrics.Load().(*Metrics)
// Swap whatever MetricSink is currently active with a BlackholeSink. Callers must not have a
// reason to expect that calls to the library will successfully collect metrics after Shutdown
// has been called.
globalMetrics.Store(&Metrics{sink: &BlackholeSink{}})
m.Shutdown()
}

View File

@ -1,184 +0,0 @@
package metrics
import (
"bytes"
"fmt"
"log"
"net"
"net/url"
"strings"
"time"
)
const (
// statsdMaxLen is the maximum size of a packet
// to send to statsd
statsdMaxLen = 1400
)
// StatsdSink provides a MetricSink that can be used
// with a statsite or statsd metrics server. It uses
// only UDP packets, while StatsiteSink uses TCP.
type StatsdSink struct {
addr string
metricQueue chan string
}
// NewStatsdSinkFromURL creates an StatsdSink from a URL. It is used
// (and tested) from NewMetricSinkFromURL.
func NewStatsdSinkFromURL(u *url.URL) (MetricSink, error) {
return NewStatsdSink(u.Host)
}
// NewStatsdSink is used to create a new StatsdSink
func NewStatsdSink(addr string) (*StatsdSink, error) {
s := &StatsdSink{
addr: addr,
metricQueue: make(chan string, 4096),
}
go s.flushMetrics()
return s, nil
}
// Close is used to stop flushing to statsd
func (s *StatsdSink) Shutdown() {
close(s.metricQueue)
}
func (s *StatsdSink) SetGauge(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val))
}
func (s *StatsdSink) SetGaugeWithLabels(key []string, val float32, labels []Label) {
flatKey := s.flattenKeyLabels(key, labels)
s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val))
}
func (s *StatsdSink) EmitKey(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|kv\n", flatKey, val))
}
func (s *StatsdSink) IncrCounter(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val))
}
func (s *StatsdSink) IncrCounterWithLabels(key []string, val float32, labels []Label) {
flatKey := s.flattenKeyLabels(key, labels)
s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val))
}
func (s *StatsdSink) AddSample(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val))
}
func (s *StatsdSink) AddSampleWithLabels(key []string, val float32, labels []Label) {
flatKey := s.flattenKeyLabels(key, labels)
s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val))
}
// Flattens the key for formatting, removes spaces
func (s *StatsdSink) flattenKey(parts []string) string {
joined := strings.Join(parts, ".")
return strings.Map(func(r rune) rune {
switch r {
case ':':
fallthrough
case ' ':
return '_'
default:
return r
}
}, joined)
}
// Flattens the key along with labels for formatting, removes spaces
func (s *StatsdSink) flattenKeyLabels(parts []string, labels []Label) string {
for _, label := range labels {
parts = append(parts, label.Value)
}
return s.flattenKey(parts)
}
// Does a non-blocking push to the metrics queue
func (s *StatsdSink) pushMetric(m string) {
select {
case s.metricQueue <- m:
default:
}
}
// Flushes metrics
func (s *StatsdSink) flushMetrics() {
var sock net.Conn
var err error
var wait <-chan time.Time
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
CONNECT:
// Create a buffer
buf := bytes.NewBuffer(nil)
// Attempt to connect
sock, err = net.Dial("udp", s.addr)
if err != nil {
log.Printf("[ERR] Error connecting to statsd! Err: %s", err)
goto WAIT
}
for {
select {
case metric, ok := <-s.metricQueue:
// Get a metric from the queue
if !ok {
goto QUIT
}
// Check if this would overflow the packet size
if len(metric)+buf.Len() > statsdMaxLen {
_, err := sock.Write(buf.Bytes())
buf.Reset()
if err != nil {
log.Printf("[ERR] Error writing to statsd! Err: %s", err)
goto WAIT
}
}
// Append to the buffer
buf.WriteString(metric)
case <-ticker.C:
if buf.Len() == 0 {
continue
}
_, err := sock.Write(buf.Bytes())
buf.Reset()
if err != nil {
log.Printf("[ERR] Error flushing to statsd! Err: %s", err)
goto WAIT
}
}
}
WAIT:
// Wait for a while
wait = time.After(time.Duration(5) * time.Second)
for {
select {
// Dequeue the messages to avoid backlog
case _, ok := <-s.metricQueue:
if !ok {
goto QUIT
}
case <-wait:
goto CONNECT
}
}
QUIT:
s.metricQueue = nil
}

View File

@ -1,172 +0,0 @@
package metrics
import (
"bufio"
"fmt"
"log"
"net"
"net/url"
"strings"
"time"
)
const (
// We force flush the statsite metrics after this period of
// inactivity. Prevents stats from getting stuck in a buffer
// forever.
flushInterval = 100 * time.Millisecond
)
// NewStatsiteSinkFromURL creates an StatsiteSink from a URL. It is used
// (and tested) from NewMetricSinkFromURL.
func NewStatsiteSinkFromURL(u *url.URL) (MetricSink, error) {
return NewStatsiteSink(u.Host)
}
// StatsiteSink provides a MetricSink that can be used with a
// statsite metrics server
type StatsiteSink struct {
addr string
metricQueue chan string
}
// NewStatsiteSink is used to create a new StatsiteSink
func NewStatsiteSink(addr string) (*StatsiteSink, error) {
s := &StatsiteSink{
addr: addr,
metricQueue: make(chan string, 4096),
}
go s.flushMetrics()
return s, nil
}
// Close is used to stop flushing to statsite
func (s *StatsiteSink) Shutdown() {
close(s.metricQueue)
}
func (s *StatsiteSink) SetGauge(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val))
}
func (s *StatsiteSink) SetGaugeWithLabels(key []string, val float32, labels []Label) {
flatKey := s.flattenKeyLabels(key, labels)
s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val))
}
func (s *StatsiteSink) EmitKey(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|kv\n", flatKey, val))
}
func (s *StatsiteSink) IncrCounter(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val))
}
func (s *StatsiteSink) IncrCounterWithLabels(key []string, val float32, labels []Label) {
flatKey := s.flattenKeyLabels(key, labels)
s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val))
}
func (s *StatsiteSink) AddSample(key []string, val float32) {
flatKey := s.flattenKey(key)
s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val))
}
func (s *StatsiteSink) AddSampleWithLabels(key []string, val float32, labels []Label) {
flatKey := s.flattenKeyLabels(key, labels)
s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val))
}
// Flattens the key for formatting, removes spaces
func (s *StatsiteSink) flattenKey(parts []string) string {
joined := strings.Join(parts, ".")
return strings.Map(func(r rune) rune {
switch r {
case ':':
fallthrough
case ' ':
return '_'
default:
return r
}
}, joined)
}
// Flattens the key along with labels for formatting, removes spaces
func (s *StatsiteSink) flattenKeyLabels(parts []string, labels []Label) string {
for _, label := range labels {
parts = append(parts, label.Value)
}
return s.flattenKey(parts)
}
// Does a non-blocking push to the metrics queue
func (s *StatsiteSink) pushMetric(m string) {
select {
case s.metricQueue <- m:
default:
}
}
// Flushes metrics
func (s *StatsiteSink) flushMetrics() {
var sock net.Conn
var err error
var wait <-chan time.Time
var buffered *bufio.Writer
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
CONNECT:
// Attempt to connect
sock, err = net.Dial("tcp", s.addr)
if err != nil {
log.Printf("[ERR] Error connecting to statsite! Err: %s", err)
goto WAIT
}
// Create a buffered writer
buffered = bufio.NewWriter(sock)
for {
select {
case metric, ok := <-s.metricQueue:
// Get a metric from the queue
if !ok {
goto QUIT
}
// Try to send to statsite
_, err := buffered.Write([]byte(metric))
if err != nil {
log.Printf("[ERR] Error writing to statsite! Err: %s", err)
goto WAIT
}
case <-ticker.C:
if err := buffered.Flush(); err != nil {
log.Printf("[ERR] Error flushing to statsite! Err: %s", err)
goto WAIT
}
}
}
WAIT:
// Wait for a while
wait = time.After(time.Duration(5) * time.Second)
for {
select {
// Dequeue the messages to avoid backlog
case _, ok := <-s.metricQueue:
if !ok {
goto QUIT
}
case <-wait:
goto CONNECT
}
}
QUIT:
s.metricQueue = nil
}

View File

@ -1,15 +0,0 @@
bin/
.idea/
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out

View File

@ -1,12 +0,0 @@
language: go
dist: xenial
go:
- '1.10'
- '1.11'
- '1.12'
- '1.13'
- 'tip'
script:
- go test -coverpkg=./... -coverprofile=coverage.info -timeout=5s
- bash <(curl -s https://codecov.io/bash)

View File

@ -1,43 +0,0 @@
# Contributor Code of Conduct
This project adheres to [The Code Manifesto](http://codemanifesto.com)
as its guidelines for contributor interactions.
## The Code Manifesto
We want to work in an ecosystem that empowers developers to reach their
potential — one that encourages growth and effective collaboration. A space
that is safe for all.
A space such as this benefits everyone that participates in it. It encourages
new developers to enter our field. It is through discussion and collaboration
that we grow, and through growth that we improve.
In the effort to create such a place, we hold to these values:
1. **Discrimination limits us.** This includes discrimination on the basis of
race, gender, sexual orientation, gender identity, age, nationality,
technology and any other arbitrary exclusion of a group of people.
2. **Boundaries honor us.** Your comfort levels are not everyones comfort
levels. Remember that, and if brought to your attention, heed it.
3. **We are our biggest assets.** None of us were born masters of our trade.
Each of us has been helped along the way. Return that favor, when and where
you can.
4. **We are resources for the future.** As an extension of #3, share what you
know. Make yourself a resource to help those that come after you.
5. **Respect defines us.** Treat others as you wish to be treated. Make your
discussions, criticisms and debates from a position of respectfulness. Ask
yourself, is it true? Is it necessary? Is it constructive? Anything less is
unacceptable.
6. **Reactions require grace.** Angry responses are valid, but abusive language
and vindictive actions are toxic. When something happens that offends you,
handle it assertively, but be respectful. Escalate reasonably, and try to
allow the offender an opportunity to explain themselves, and possibly
correct the issue.
7. **Opinions are just that: opinions.** Each and every one of us, due to our
background and upbringing, have varying opinions. That is perfectly
acceptable. Remember this: if you respect your own opinions, you should
respect the opinions of others.
8. **To err is human.** You might not intend it, but mistakes do happen and
contribute to build experience. Tolerate honest mistakes, and don't
hesitate to apologize if you make one yourself.

View File

@ -1,63 +0,0 @@
#### Support
If you do have a contribution to the package, feel free to create a Pull Request or an Issue.
#### What to contribute
If you don't know what to do, there are some features and functions that need to be done
- [ ] Refactor code
- [ ] Edit docs and [README](https://github.com/asaskevich/govalidator/README.md): spellcheck, grammar and typo check
- [ ] Create actual list of contributors and projects that currently using this package
- [ ] Resolve [issues and bugs](https://github.com/asaskevich/govalidator/issues)
- [ ] Update actual [list of functions](https://github.com/asaskevich/govalidator#list-of-functions)
- [ ] Update [list of validators](https://github.com/asaskevich/govalidator#validatestruct-2) that available for `ValidateStruct` and add new
- [ ] Implement new validators: `IsFQDN`, `IsIMEI`, `IsPostalCode`, `IsISIN`, `IsISRC` etc
- [x] Implement [validation by maps](https://github.com/asaskevich/govalidator/issues/224)
- [ ] Implement fuzzing testing
- [ ] Implement some struct/map/array utilities
- [ ] Implement map/array validation
- [ ] Implement benchmarking
- [ ] Implement batch of examples
- [ ] Look at forks for new features and fixes
#### Advice
Feel free to create what you want, but keep in mind when you implement new features:
- Code must be clear and readable, names of variables/constants clearly describes what they are doing
- Public functions must be documented and described in source file and added to README.md to the list of available functions
- There are must be unit-tests for any new functions and improvements
## Financial contributions
We also welcome financial contributions in full transparency on our [open collective](https://opencollective.com/govalidator).
Anyone can file an expense. If the expense makes sense for the development of the community, it will be "merged" in the ledger of our open collective by the core contributors and the person who filed the expense will be reimbursed.
## Credits
### Contributors
Thank you to all the people who have already contributed to govalidator!
<a href="https://github.com/asaskevich/govalidator/graphs/contributors"><img src="https://opencollective.com/govalidator/contributors.svg?width=890" /></a>
### Backers
Thank you to all our backers! [[Become a backer](https://opencollective.com/govalidator#backer)]
<a href="https://opencollective.com/govalidator#backers" target="_blank"><img src="https://opencollective.com/govalidator/backers.svg?width=890"></a>
### Sponsors
Thank you to all our sponsors! (please ask your company to also support this open source project by [becoming a sponsor](https://opencollective.com/govalidator#sponsor))
<a href="https://opencollective.com/govalidator/sponsor/0/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/0/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/1/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/1/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/2/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/2/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/3/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/3/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/4/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/4/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/5/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/5/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/6/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/6/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/7/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/7/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/8/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/8/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/9/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/9/avatar.svg"></a>

View File

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2014-2020 Alex Saskevich
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,622 +0,0 @@
govalidator
===========
[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/asaskevich/govalidator?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![GoDoc](https://godoc.org/github.com/asaskevich/govalidator?status.png)](https://godoc.org/github.com/asaskevich/govalidator)
[![Build Status](https://travis-ci.org/asaskevich/govalidator.svg?branch=master)](https://travis-ci.org/asaskevich/govalidator)
[![Coverage](https://codecov.io/gh/asaskevich/govalidator/branch/master/graph/badge.svg)](https://codecov.io/gh/asaskevich/govalidator) [![Go Report Card](https://goreportcard.com/badge/github.com/asaskevich/govalidator)](https://goreportcard.com/report/github.com/asaskevich/govalidator) [![GoSearch](http://go-search.org/badge?id=github.com%2Fasaskevich%2Fgovalidator)](http://go-search.org/view?id=github.com%2Fasaskevich%2Fgovalidator) [![Backers on Open Collective](https://opencollective.com/govalidator/backers/badge.svg)](#backers) [![Sponsors on Open Collective](https://opencollective.com/govalidator/sponsors/badge.svg)](#sponsors) [![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator.svg?type=shield)](https://app.fossa.io/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator?ref=badge_shield)
A package of validators and sanitizers for strings, structs and collections. Based on [validator.js](https://github.com/chriso/validator.js).
#### Installation
Make sure that Go is installed on your computer.
Type the following command in your terminal:
go get github.com/asaskevich/govalidator
or you can get specified release of the package with `gopkg.in`:
go get gopkg.in/asaskevich/govalidator.v10
After it the package is ready to use.
#### Import package in your project
Add following line in your `*.go` file:
```go
import "github.com/asaskevich/govalidator"
```
If you are unhappy to use long `govalidator`, you can do something like this:
```go
import (
valid "github.com/asaskevich/govalidator"
)
```
#### Activate behavior to require all fields have a validation tag by default
`SetFieldsRequiredByDefault` causes validation to fail when struct fields do not include validations or are not explicitly marked as exempt (using `valid:"-"` or `valid:"email,optional"`). A good place to activate this is a package init function or the main() function.
`SetNilPtrAllowedByRequired` causes validation to pass when struct fields marked by `required` are set to nil. This is disabled by default for consistency, but some packages that need to be able to determine between `nil` and `zero value` state can use this. If disabled, both `nil` and `zero` values cause validation errors.
```go
import "github.com/asaskevich/govalidator"
func init() {
govalidator.SetFieldsRequiredByDefault(true)
}
```
Here's some code to explain it:
```go
// this struct definition will fail govalidator.ValidateStruct() (and the field values do not matter):
type exampleStruct struct {
Name string ``
Email string `valid:"email"`
}
// this, however, will only fail when Email is empty or an invalid email address:
type exampleStruct2 struct {
Name string `valid:"-"`
Email string `valid:"email"`
}
// lastly, this will only fail when Email is an invalid email address but not when it's empty:
type exampleStruct2 struct {
Name string `valid:"-"`
Email string `valid:"email,optional"`
}
```
#### Recent breaking changes (see [#123](https://github.com/asaskevich/govalidator/pull/123))
##### Custom validator function signature
A context was added as the second parameter, for structs this is the object being validated this makes dependent validation possible.
```go
import "github.com/asaskevich/govalidator"
// old signature
func(i interface{}) bool
// new signature
func(i interface{}, o interface{}) bool
```
##### Adding a custom validator
This was changed to prevent data races when accessing custom validators.
```go
import "github.com/asaskevich/govalidator"
// before
govalidator.CustomTypeTagMap["customByteArrayValidator"] = func(i interface{}, o interface{}) bool {
// ...
}
// after
govalidator.CustomTypeTagMap.Set("customByteArrayValidator", func(i interface{}, o interface{}) bool {
// ...
})
```
#### List of functions:
```go
func Abs(value float64) float64
func BlackList(str, chars string) string
func ByteLength(str string, params ...string) bool
func CamelCaseToUnderscore(str string) string
func Contains(str, substring string) bool
func Count(array []interface{}, iterator ConditionIterator) int
func Each(array []interface{}, iterator Iterator)
func ErrorByField(e error, field string) string
func ErrorsByField(e error) map[string]string
func Filter(array []interface{}, iterator ConditionIterator) []interface{}
func Find(array []interface{}, iterator ConditionIterator) interface{}
func GetLine(s string, index int) (string, error)
func GetLines(s string) []string
func HasLowerCase(str string) bool
func HasUpperCase(str string) bool
func HasWhitespace(str string) bool
func HasWhitespaceOnly(str string) bool
func InRange(value interface{}, left interface{}, right interface{}) bool
func InRangeFloat32(value, left, right float32) bool
func InRangeFloat64(value, left, right float64) bool
func InRangeInt(value, left, right interface{}) bool
func IsASCII(str string) bool
func IsAlpha(str string) bool
func IsAlphanumeric(str string) bool
func IsBase64(str string) bool
func IsByteLength(str string, min, max int) bool
func IsCIDR(str string) bool
func IsCRC32(str string) bool
func IsCRC32b(str string) bool
func IsCreditCard(str string) bool
func IsDNSName(str string) bool
func IsDataURI(str string) bool
func IsDialString(str string) bool
func IsDivisibleBy(str, num string) bool
func IsEmail(str string) bool
func IsExistingEmail(email string) bool
func IsFilePath(str string) (bool, int)
func IsFloat(str string) bool
func IsFullWidth(str string) bool
func IsHalfWidth(str string) bool
func IsHash(str string, algorithm string) bool
func IsHexadecimal(str string) bool
func IsHexcolor(str string) bool
func IsHost(str string) bool
func IsIP(str string) bool
func IsIPv4(str string) bool
func IsIPv6(str string) bool
func IsISBN(str string, version int) bool
func IsISBN10(str string) bool
func IsISBN13(str string) bool
func IsISO3166Alpha2(str string) bool
func IsISO3166Alpha3(str string) bool
func IsISO4217(str string) bool
func IsISO693Alpha2(str string) bool
func IsISO693Alpha3b(str string) bool
func IsIn(str string, params ...string) bool
func IsInRaw(str string, params ...string) bool
func IsInt(str string) bool
func IsJSON(str string) bool
func IsLatitude(str string) bool
func IsLongitude(str string) bool
func IsLowerCase(str string) bool
func IsMAC(str string) bool
func IsMD4(str string) bool
func IsMD5(str string) bool
func IsMagnetURI(str string) bool
func IsMongoID(str string) bool
func IsMultibyte(str string) bool
func IsNatural(value float64) bool
func IsNegative(value float64) bool
func IsNonNegative(value float64) bool
func IsNonPositive(value float64) bool
func IsNotNull(str string) bool
func IsNull(str string) bool
func IsNumeric(str string) bool
func IsPort(str string) bool
func IsPositive(value float64) bool
func IsPrintableASCII(str string) bool
func IsRFC3339(str string) bool
func IsRFC3339WithoutZone(str string) bool
func IsRGBcolor(str string) bool
func IsRegex(str string) bool
func IsRequestURI(rawurl string) bool
func IsRequestURL(rawurl string) bool
func IsRipeMD128(str string) bool
func IsRipeMD160(str string) bool
func IsRsaPub(str string, params ...string) bool
func IsRsaPublicKey(str string, keylen int) bool
func IsSHA1(str string) bool
func IsSHA256(str string) bool
func IsSHA384(str string) bool
func IsSHA512(str string) bool
func IsSSN(str string) bool
func IsSemver(str string) bool
func IsTiger128(str string) bool
func IsTiger160(str string) bool
func IsTiger192(str string) bool
func IsTime(str string, format string) bool
func IsType(v interface{}, params ...string) bool
func IsURL(str string) bool
func IsUTFDigit(str string) bool
func IsUTFLetter(str string) bool
func IsUTFLetterNumeric(str string) bool
func IsUTFNumeric(str string) bool
func IsUUID(str string) bool
func IsUUIDv3(str string) bool
func IsUUIDv4(str string) bool
func IsUUIDv5(str string) bool
func IsULID(str string) bool
func IsUnixTime(str string) bool
func IsUpperCase(str string) bool
func IsVariableWidth(str string) bool
func IsWhole(value float64) bool
func LeftTrim(str, chars string) string
func Map(array []interface{}, iterator ResultIterator) []interface{}
func Matches(str, pattern string) bool
func MaxStringLength(str string, params ...string) bool
func MinStringLength(str string, params ...string) bool
func NormalizeEmail(str string) (string, error)
func PadBoth(str string, padStr string, padLen int) string
func PadLeft(str string, padStr string, padLen int) string
func PadRight(str string, padStr string, padLen int) string
func PrependPathToErrors(err error, path string) error
func Range(str string, params ...string) bool
func RemoveTags(s string) string
func ReplacePattern(str, pattern, replace string) string
func Reverse(s string) string
func RightTrim(str, chars string) string
func RuneLength(str string, params ...string) bool
func SafeFileName(str string) string
func SetFieldsRequiredByDefault(value bool)
func SetNilPtrAllowedByRequired(value bool)
func Sign(value float64) float64
func StringLength(str string, params ...string) bool
func StringMatches(s string, params ...string) bool
func StripLow(str string, keepNewLines bool) string
func ToBoolean(str string) (bool, error)
func ToFloat(str string) (float64, error)
func ToInt(value interface{}) (res int64, err error)
func ToJSON(obj interface{}) (string, error)
func ToString(obj interface{}) string
func Trim(str, chars string) string
func Truncate(str string, length int, ending string) string
func TruncatingErrorf(str string, args ...interface{}) error
func UnderscoreToCamelCase(s string) string
func ValidateMap(inputMap map[string]interface{}, validationMap map[string]interface{}) (bool, error)
func ValidateStruct(s interface{}) (bool, error)
func WhiteList(str, chars string) string
type ConditionIterator
type CustomTypeValidator
type Error
func (e Error) Error() string
type Errors
func (es Errors) Error() string
func (es Errors) Errors() []error
type ISO3166Entry
type ISO693Entry
type InterfaceParamValidator
type Iterator
type ParamValidator
type ResultIterator
type UnsupportedTypeError
func (e *UnsupportedTypeError) Error() string
type Validator
```
#### Examples
###### IsURL
```go
println(govalidator.IsURL(`http://user@pass:domain.com/path/page`))
```
###### IsType
```go
println(govalidator.IsType("Bob", "string"))
println(govalidator.IsType(1, "int"))
i := 1
println(govalidator.IsType(&i, "*int"))
```
IsType can be used through the tag `type` which is essential for map validation:
```go
type User struct {
Name string `valid:"type(string)"`
Age int `valid:"type(int)"`
Meta interface{} `valid:"type(string)"`
}
result, err := govalidator.ValidateStruct(User{"Bob", 20, "meta"})
if err != nil {
println("error: " + err.Error())
}
println(result)
```
###### ToString
```go
type User struct {
FirstName string
LastName string
}
str := govalidator.ToString(&User{"John", "Juan"})
println(str)
```
###### Each, Map, Filter, Count for slices
Each iterates over the slice/array and calls Iterator for every item
```go
data := []interface{}{1, 2, 3, 4, 5}
var fn govalidator.Iterator = func(value interface{}, index int) {
println(value.(int))
}
govalidator.Each(data, fn)
```
```go
data := []interface{}{1, 2, 3, 4, 5}
var fn govalidator.ResultIterator = func(value interface{}, index int) interface{} {
return value.(int) * 3
}
_ = govalidator.Map(data, fn) // result = []interface{}{1, 6, 9, 12, 15}
```
```go
data := []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
var fn govalidator.ConditionIterator = func(value interface{}, index int) bool {
return value.(int)%2 == 0
}
_ = govalidator.Filter(data, fn) // result = []interface{}{2, 4, 6, 8, 10}
_ = govalidator.Count(data, fn) // result = 5
```
###### ValidateStruct [#2](https://github.com/asaskevich/govalidator/pull/2)
If you want to validate structs, you can use tag `valid` for any field in your structure. All validators used with this field in one tag are separated by comma. If you want to skip validation, place `-` in your tag. If you need a validator that is not on the list below, you can add it like this:
```go
govalidator.TagMap["duck"] = govalidator.Validator(func(str string) bool {
return str == "duck"
})
```
For completely custom validators (interface-based), see below.
Here is a list of available validators for struct fields (validator - used function):
```go
"email": IsEmail,
"url": IsURL,
"dialstring": IsDialString,
"requrl": IsRequestURL,
"requri": IsRequestURI,
"alpha": IsAlpha,
"utfletter": IsUTFLetter,
"alphanum": IsAlphanumeric,
"utfletternum": IsUTFLetterNumeric,
"numeric": IsNumeric,
"utfnumeric": IsUTFNumeric,
"utfdigit": IsUTFDigit,
"hexadecimal": IsHexadecimal,
"hexcolor": IsHexcolor,
"rgbcolor": IsRGBcolor,
"lowercase": IsLowerCase,
"uppercase": IsUpperCase,
"int": IsInt,
"float": IsFloat,
"null": IsNull,
"uuid": IsUUID,
"uuidv3": IsUUIDv3,
"uuidv4": IsUUIDv4,
"uuidv5": IsUUIDv5,
"creditcard": IsCreditCard,
"isbn10": IsISBN10,
"isbn13": IsISBN13,
"json": IsJSON,
"multibyte": IsMultibyte,
"ascii": IsASCII,
"printableascii": IsPrintableASCII,
"fullwidth": IsFullWidth,
"halfwidth": IsHalfWidth,
"variablewidth": IsVariableWidth,
"base64": IsBase64,
"datauri": IsDataURI,
"ip": IsIP,
"port": IsPort,
"ipv4": IsIPv4,
"ipv6": IsIPv6,
"dns": IsDNSName,
"host": IsHost,
"mac": IsMAC,
"latitude": IsLatitude,
"longitude": IsLongitude,
"ssn": IsSSN,
"semver": IsSemver,
"rfc3339": IsRFC3339,
"rfc3339WithoutZone": IsRFC3339WithoutZone,
"ISO3166Alpha2": IsISO3166Alpha2,
"ISO3166Alpha3": IsISO3166Alpha3,
"ulid": IsULID,
```
Validators with parameters
```go
"range(min|max)": Range,
"length(min|max)": ByteLength,
"runelength(min|max)": RuneLength,
"stringlength(min|max)": StringLength,
"matches(pattern)": StringMatches,
"in(string1|string2|...|stringN)": IsIn,
"rsapub(keylength)" : IsRsaPub,
"minstringlength(int): MinStringLength,
"maxstringlength(int): MaxStringLength,
```
Validators with parameters for any type
```go
"type(type)": IsType,
```
And here is small example of usage:
```go
type Post struct {
Title string `valid:"alphanum,required"`
Message string `valid:"duck,ascii"`
Message2 string `valid:"animal(dog)"`
AuthorIP string `valid:"ipv4"`
Date string `valid:"-"`
}
post := &Post{
Title: "My Example Post",
Message: "duck",
Message2: "dog",
AuthorIP: "123.234.54.3",
}
// Add your own struct validation tags
govalidator.TagMap["duck"] = govalidator.Validator(func(str string) bool {
return str == "duck"
})
// Add your own struct validation tags with parameter
govalidator.ParamTagMap["animal"] = govalidator.ParamValidator(func(str string, params ...string) bool {
species := params[0]
return str == species
})
govalidator.ParamTagRegexMap["animal"] = regexp.MustCompile("^animal\\((\\w+)\\)$")
result, err := govalidator.ValidateStruct(post)
if err != nil {
println("error: " + err.Error())
}
println(result)
```
###### ValidateMap [#2](https://github.com/asaskevich/govalidator/pull/338)
If you want to validate maps, you can use the map to be validated and a validation map that contain the same tags used in ValidateStruct, both maps have to be in the form `map[string]interface{}`
So here is small example of usage:
```go
var mapTemplate = map[string]interface{}{
"name":"required,alpha",
"family":"required,alpha",
"email":"required,email",
"cell-phone":"numeric",
"address":map[string]interface{}{
"line1":"required,alphanum",
"line2":"alphanum",
"postal-code":"numeric",
},
}
var inputMap = map[string]interface{}{
"name":"Bob",
"family":"Smith",
"email":"foo@bar.baz",
"address":map[string]interface{}{
"line1":"",
"line2":"",
"postal-code":"",
},
}
result, err := govalidator.ValidateMap(inputMap, mapTemplate)
if err != nil {
println("error: " + err.Error())
}
println(result)
```
###### WhiteList
```go
// Remove all characters from string ignoring characters between "a" and "z"
println(govalidator.WhiteList("a3a43a5a4a3a2a23a4a5a4a3a4", "a-z") == "aaaaaaaaaaaa")
```
###### Custom validation functions
Custom validation using your own domain specific validators is also available - here's an example of how to use it:
```go
import "github.com/asaskevich/govalidator"
type CustomByteArray [6]byte // custom types are supported and can be validated
type StructWithCustomByteArray struct {
ID CustomByteArray `valid:"customByteArrayValidator,customMinLengthValidator"` // multiple custom validators are possible as well and will be evaluated in sequence
Email string `valid:"email"`
CustomMinLength int `valid:"-"`
}
govalidator.CustomTypeTagMap.Set("customByteArrayValidator", func(i interface{}, context interface{}) bool {
switch v := context.(type) { // you can type switch on the context interface being validated
case StructWithCustomByteArray:
// you can check and validate against some other field in the context,
// return early or not validate against the context at all your choice
case SomeOtherType:
// ...
default:
// expecting some other type? Throw/panic here or continue
}
switch v := i.(type) { // type switch on the struct field being validated
case CustomByteArray:
for _, e := range v { // this validator checks that the byte array is not empty, i.e. not all zeroes
if e != 0 {
return true
}
}
}
return false
})
govalidator.CustomTypeTagMap.Set("customMinLengthValidator", func(i interface{}, context interface{}) bool {
switch v := context.(type) { // this validates a field against the value in another field, i.e. dependent validation
case StructWithCustomByteArray:
return len(v.ID) >= v.CustomMinLength
}
return false
})
```
###### Loop over Error()
By default .Error() returns all errors in a single String. To access each error you can do this:
```go
if err != nil {
errs := err.(govalidator.Errors).Errors()
for _, e := range errs {
fmt.Println(e.Error())
}
}
```
###### Custom error messages
Custom error messages are supported via annotations by adding the `~` separator - here's an example of how to use it:
```go
type Ticket struct {
Id int64 `json:"id"`
FirstName string `json:"firstname" valid:"required~First name is blank"`
}
```
#### Notes
Documentation is available here: [godoc.org](https://godoc.org/github.com/asaskevich/govalidator).
Full information about code coverage is also available here: [govalidator on gocover.io](http://gocover.io/github.com/asaskevich/govalidator).
#### Support
If you do have a contribution to the package, feel free to create a Pull Request or an Issue.
#### What to contribute
If you don't know what to do, there are some features and functions that need to be done
- [ ] Refactor code
- [ ] Edit docs and [README](https://github.com/asaskevich/govalidator/README.md): spellcheck, grammar and typo check
- [ ] Create actual list of contributors and projects that currently using this package
- [ ] Resolve [issues and bugs](https://github.com/asaskevich/govalidator/issues)
- [ ] Update actual [list of functions](https://github.com/asaskevich/govalidator#list-of-functions)
- [ ] Update [list of validators](https://github.com/asaskevich/govalidator#validatestruct-2) that available for `ValidateStruct` and add new
- [ ] Implement new validators: `IsFQDN`, `IsIMEI`, `IsPostalCode`, `IsISIN`, `IsISRC` etc
- [x] Implement [validation by maps](https://github.com/asaskevich/govalidator/issues/224)
- [ ] Implement fuzzing testing
- [ ] Implement some struct/map/array utilities
- [ ] Implement map/array validation
- [ ] Implement benchmarking
- [ ] Implement batch of examples
- [ ] Look at forks for new features and fixes
#### Advice
Feel free to create what you want, but keep in mind when you implement new features:
- Code must be clear and readable, names of variables/constants clearly describes what they are doing
- Public functions must be documented and described in source file and added to README.md to the list of available functions
- There are must be unit-tests for any new functions and improvements
## Credits
### Contributors
This project exists thanks to all the people who contribute. [[Contribute](CONTRIBUTING.md)].
#### Special thanks to [contributors](https://github.com/asaskevich/govalidator/graphs/contributors)
* [Daniel Lohse](https://github.com/annismckenzie)
* [Attila Oláh](https://github.com/attilaolah)
* [Daniel Korner](https://github.com/Dadie)
* [Steven Wilkin](https://github.com/stevenwilkin)
* [Deiwin Sarjas](https://github.com/deiwin)
* [Noah Shibley](https://github.com/slugmobile)
* [Nathan Davies](https://github.com/nathj07)
* [Matt Sanford](https://github.com/mzsanford)
* [Simon ccl1115](https://github.com/ccl1115)
<a href="https://github.com/asaskevich/govalidator/graphs/contributors"><img src="https://opencollective.com/govalidator/contributors.svg?width=890" /></a>
### Backers
Thank you to all our backers! 🙏 [[Become a backer](https://opencollective.com/govalidator#backer)]
<a href="https://opencollective.com/govalidator#backers" target="_blank"><img src="https://opencollective.com/govalidator/backers.svg?width=890"></a>
### Sponsors
Support this project by becoming a sponsor. Your logo will show up here with a link to your website. [[Become a sponsor](https://opencollective.com/govalidator#sponsor)]
<a href="https://opencollective.com/govalidator/sponsor/0/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/0/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/1/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/1/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/2/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/2/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/3/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/3/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/4/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/4/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/5/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/5/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/6/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/6/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/7/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/7/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/8/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/8/avatar.svg"></a>
<a href="https://opencollective.com/govalidator/sponsor/9/website" target="_blank"><img src="https://opencollective.com/govalidator/sponsor/9/avatar.svg"></a>
## License
[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator?ref=badge_large)

View File

@ -1,87 +0,0 @@
package govalidator
// Iterator is the function that accepts element of slice/array and its index
type Iterator func(interface{}, int)
// ResultIterator is the function that accepts element of slice/array and its index and returns any result
type ResultIterator func(interface{}, int) interface{}
// ConditionIterator is the function that accepts element of slice/array and its index and returns boolean
type ConditionIterator func(interface{}, int) bool
// ReduceIterator is the function that accepts two element of slice/array and returns result of merging those values
type ReduceIterator func(interface{}, interface{}) interface{}
// Some validates that any item of array corresponds to ConditionIterator. Returns boolean.
func Some(array []interface{}, iterator ConditionIterator) bool {
res := false
for index, data := range array {
res = res || iterator(data, index)
}
return res
}
// Every validates that every item of array corresponds to ConditionIterator. Returns boolean.
func Every(array []interface{}, iterator ConditionIterator) bool {
res := true
for index, data := range array {
res = res && iterator(data, index)
}
return res
}
// Reduce boils down a list of values into a single value by ReduceIterator
func Reduce(array []interface{}, iterator ReduceIterator, initialValue interface{}) interface{} {
for _, data := range array {
initialValue = iterator(initialValue, data)
}
return initialValue
}
// Each iterates over the slice and apply Iterator to every item
func Each(array []interface{}, iterator Iterator) {
for index, data := range array {
iterator(data, index)
}
}
// Map iterates over the slice and apply ResultIterator to every item. Returns new slice as a result.
func Map(array []interface{}, iterator ResultIterator) []interface{} {
var result = make([]interface{}, len(array))
for index, data := range array {
result[index] = iterator(data, index)
}
return result
}
// Find iterates over the slice and apply ConditionIterator to every item. Returns first item that meet ConditionIterator or nil otherwise.
func Find(array []interface{}, iterator ConditionIterator) interface{} {
for index, data := range array {
if iterator(data, index) {
return data
}
}
return nil
}
// Filter iterates over the slice and apply ConditionIterator to every item. Returns new slice.
func Filter(array []interface{}, iterator ConditionIterator) []interface{} {
var result = make([]interface{}, 0)
for index, data := range array {
if iterator(data, index) {
result = append(result, data)
}
}
return result
}
// Count iterates over the slice and apply ConditionIterator to every item. Returns count of items that meets ConditionIterator.
func Count(array []interface{}, iterator ConditionIterator) int {
count := 0
for index, data := range array {
if iterator(data, index) {
count = count + 1
}
}
return count
}

View File

@ -1,81 +0,0 @@
package govalidator
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
)
// ToString convert the input to a string.
func ToString(obj interface{}) string {
res := fmt.Sprintf("%v", obj)
return res
}
// ToJSON convert the input to a valid JSON string
func ToJSON(obj interface{}) (string, error) {
res, err := json.Marshal(obj)
if err != nil {
res = []byte("")
}
return string(res), err
}
// ToFloat convert the input string to a float, or 0.0 if the input is not a float.
func ToFloat(value interface{}) (res float64, err error) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
res = float64(val.Int())
case uint, uint8, uint16, uint32, uint64:
res = float64(val.Uint())
case float32, float64:
res = val.Float()
case string:
res, err = strconv.ParseFloat(val.String(), 64)
if err != nil {
res = 0
}
default:
err = fmt.Errorf("ToInt: unknown interface type %T", value)
res = 0
}
return
}
// ToInt convert the input string or any int type to an integer type 64, or 0 if the input is not an integer.
func ToInt(value interface{}) (res int64, err error) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
res = val.Int()
case uint, uint8, uint16, uint32, uint64:
res = int64(val.Uint())
case float32, float64:
res = int64(val.Float())
case string:
if IsInt(val.String()) {
res, err = strconv.ParseInt(val.String(), 0, 64)
if err != nil {
res = 0
}
} else {
err = fmt.Errorf("ToInt: invalid numeric format %g", value)
res = 0
}
default:
err = fmt.Errorf("ToInt: unknown interface type %T", value)
res = 0
}
return
}
// ToBoolean convert the input string to a boolean.
func ToBoolean(str string) (bool, error) {
return strconv.ParseBool(str)
}

View File

@ -1,3 +0,0 @@
package govalidator
// A package of validators and sanitizers for strings, structures and collections.

View File

@ -1,47 +0,0 @@
package govalidator
import (
"sort"
"strings"
)
// Errors is an array of multiple errors and conforms to the error interface.
type Errors []error
// Errors returns itself.
func (es Errors) Errors() []error {
return es
}
func (es Errors) Error() string {
var errs []string
for _, e := range es {
errs = append(errs, e.Error())
}
sort.Strings(errs)
return strings.Join(errs, ";")
}
// Error encapsulates a name, an error and whether there's a custom error message or not.
type Error struct {
Name string
Err error
CustomErrorMessageExists bool
// Validator indicates the name of the validator that failed
Validator string
Path []string
}
func (e Error) Error() string {
if e.CustomErrorMessageExists {
return e.Err.Error()
}
errName := e.Name
if len(e.Path) > 0 {
errName = strings.Join(append(e.Path, e.Name), ".")
}
return errName + ": " + e.Err.Error()
}

Some files were not shown because too many files have changed in this diff Show More