feat: xmnt

This commit is contained in:
Yiyang Kang 2022-10-07 01:19:04 +08:00
parent 00b3e4b24f
commit 5a998b713c
16 changed files with 1481 additions and 0 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/xmnt /xmnt
/.idea

12
Makefile Normal file
View File

@ -0,0 +1,12 @@
BIN_NAME := xmnt
build:
CGO_ENABLED=0 go build -trimpath -ldflags '-s -w' -o "$(BIN_NAME)"
install: build
install -vDt "${HOME}/.local/bin" "$(BIN_NAME)"
clean:
rm -vf "$(BIN_NAME)"
.PHONY: build install clean

318
blk/blk.go Normal file
View File

@ -0,0 +1,318 @@
package blk
import (
"bytes"
"encoding/json"
"regexp"
"strings"
"github.com/go-errors/errors"
"golang.org/x/exp/slices"
"gensokyo.cafe/xmnt/cfg"
"gensokyo.cafe/xmnt/mnt"
"gensokyo.cafe/xmnt/msg"
"gensokyo.cafe/xmnt/util"
)
type FSType string
const FSTypeLuks = FSType("crypto_LUKS")
type DevNum string
var supportedFsTypes = []FSType{
"ext2", "ext3", "ext4", "xfs", "vfat", "ntfs", FSTypeLuks,
}
type BlkDev struct {
Name *string `json:"name"`
PKName *string `json:"pkname"` // name of parent device
Path *string `json:"path"`
FSType *FSType `json:"fstype"`
MountPoint *string `json:"mountpoint"`
UUID *string `json:"uuid"`
DevNum *DevNum `json:"maj:min"`
Children []*BlkDev `json:"children"`
}
func (b *BlkDev) IsSupportedType() bool {
if b.FSType == nil {
return false
}
return slices.Contains(supportedFsTypes, *b.FSType)
}
func (b *BlkDev) NeedsDecryption() bool {
if b.FSType == nil || *b.FSType != FSTypeLuks {
return false
}
return len(b.Children) == 0
}
func (b *BlkDev) IsMounted() bool {
if b.MountPoint != nil && (*b.MountPoint)[:1] == "/" {
return true
}
for _, child := range b.Children {
if child.IsMounted() {
return true
}
}
return false
}
func flattenBlkDevs(devs []*BlkDev) []*BlkDev {
var ret []*BlkDev
for _, dev := range devs {
ret = append(ret, dev)
ret = append(ret, flattenBlkDevs(dev.Children)...)
}
return ret
}
func List(path string) ([]*BlkDev, error) {
lsblkArgs := []string{"-JT", "-o", "name,pkname,path,fstype,mountpoint,uuid"}
if path != "" {
lsblkArgs = append(lsblkArgs, path)
}
output, err := util.RunCommand("lsblk", nil, lsblkArgs...)
if err != nil {
return nil, errors.WrapPrefix(err, "cannot obtain list of block devices", 0)
}
resp := struct {
BlockDevices []*BlkDev `json:"blockdevices"`
}{}
if err := json.Unmarshal(output, &resp); err != nil {
return nil, errors.WrapPrefix(err, "cannot parse lsblk output", 0)
}
return flattenBlkDevs(resp.BlockDevices), nil
}
type Mounter struct {
dev *BlkDev
preset *mnt.Preset
}
func NewMounterFromPreset(p *mnt.Preset) (mnt.Mounter, error) {
preset := &*p
if preset.Path == "" {
return nil, errors.New("preset path is empty")
}
if !util.IsValidMountPoint(preset.MountPoint) {
return nil, errors.Errorf("invalid mount point %q", preset.MountPoint)
}
m := &Mounter{
preset: preset,
}
if err := m.refresh(); err != nil {
return nil, errors.Wrap(err, 0)
}
return m, nil
}
func (m *Mounter) refresh() error {
devs, err := List(m.preset.Path)
if err != nil {
return errors.Wrap(err, 0)
}
if len(devs) < 1 {
return errors.Errorf("block device %q not found", m.preset.Path)
}
m.dev = devs[0]
return nil
}
func escapeDmName(s string) string {
b := regexp.MustCompile(`\W`).ReplaceAllLiteral([]byte(s), []byte{'_'})
b = regexp.MustCompile(`_+`).ReplaceAllLiteral(b, []byte{'_'})
b = bytes.Trim(b, "_")
return string(b)
}
func (m *Mounter) dmName() string {
for _, c := range m.dev.Children {
if strings.HasPrefix(*c.Path, "/dev/mapper/") {
return *c.Name
}
}
// TODO get from property value of systemd mount unit
if s := escapeDmName(m.preset.Name); s != "" {
return s
}
return escapeDmName(*m.dev.Path)
}
func (m *Mounter) loadKey() error {
if !m.dev.NeedsDecryption() {
return nil
}
if m.dev.UUID == nil {
return errors.Errorf("device %q does not have a UUID", *m.dev.Path)
}
dmName := m.dmName()
if dmName == "" {
return errors.Errorf("cannot determine device mapper name for device %q", *m.dev.Path)
}
cred, err := util.ReadCredentialFile(*m.dev.UUID, cfg.Cfg.CredentialStore)
if err != nil {
return errors.Wrap(err, 0)
}
msg.Infof("Opening device %q as %q", *m.dev.Path, dmName)
if _, err := util.RunPrivilegedCommand(
"cryptsetup", strings.NewReader(cred),
"luksOpen", *m.dev.Path, dmName,
); err != nil {
return errors.WrapPrefix(err, "cannot open luks device", 0)
}
return nil
}
func (m *Mounter) mount() error {
if m.dev.IsMounted() {
return nil
}
if !m.dev.IsSupportedType() || m.dev.NeedsDecryption() {
return errors.Errorf("cannot mount %q: unsupported filesystem type %q", *m.dev.Path, *m.dev.FSType)
}
if len(m.dev.Children) > 1 {
return errors.Errorf("cannot mount %q: can't handle multiple child devices", *m.dev.Path)
}
dev := m.dev
if len(dev.Children) > 0 {
dev = dev.Children[0]
}
mp := m.preset.MountPoint
if err := util.SystemdMount(mp); err == nil {
return nil
} else if !util.ShouldSkipSdMount(err) {
return errors.WrapPrefix(err, "cannot mount device", 0)
}
msg.Infof("Mounting %q on %q", *dev.Path, mp)
mountOpts := []string{"-o", "noatime", *dev.Path, mp}
if _, err := util.RunPrivilegedCommand("mount", nil, mountOpts...); err != nil {
return errors.WrapPrefix(err, "failed to mount device", 0)
}
return nil
}
func (m *Mounter) Mount() (err error) {
if err = m.loadKey(); err != nil {
return errors.Wrap(err, 0)
}
if err = m.refresh(); err != nil {
return errors.Wrap(err, 0)
}
if err = m.mount(); err != nil {
return errors.Wrap(err, 0)
}
// TODO check
return nil
}
func (m *Mounter) unmount() error {
if !m.dev.IsMounted() {
return nil
}
if len(m.dev.Children) > 1 {
return errors.Errorf("cannot unmount %q: can't handle multiple child devices", *m.dev.Path)
}
dev := m.dev
if len(dev.Children) > 0 {
dev = dev.Children[0]
}
mp := *m.dev.MountPoint
if err := util.SystemdUnmount(mp); err == nil {
return nil
} else if !util.ShouldSkipSdMount(err) {
return errors.WrapPrefix(err, "cannot unmount device", 0)
}
msg.Infof("Unmounting %q on %q", *dev.Path, mp)
if _, err := util.RunPrivilegedCommand("umount", nil, mp); err != nil {
return errors.WrapPrefix(err, "failed to unmount device", 0)
}
return nil
}
func (m *Mounter) unloadKey() error {
if *m.dev.FSType != FSTypeLuks {
return nil
}
if m.dev.NeedsDecryption() {
return nil
}
dmName := m.dmName()
if dmName == "" {
return errors.New("cannot determine device mapper name")
}
if _, err := util.RunPrivilegedCommand("cryptsetup", nil, "close", dmName); err != nil {
return errors.WrapPrefix(err, "cannot close luks device", 0)
}
return nil
}
func (m *Mounter) Unmount() (err error) {
if err = m.unmount(); err != nil {
return errors.Wrap(err, 0)
}
if err = m.unloadKey(); err != nil {
return errors.Wrap(err, 0)
}
// TODO check
return nil
}
func match(s string) ([]*mnt.Preset, error) {
if s == "" {
return nil, nil
}
devices, err := List("")
if err != nil {
return nil, errors.Wrap(err, 0)
}
var partialMatch []*BlkDev
for _, dev := range devices {
if dev.Name != nil && *dev.Name == s ||
dev.Path != nil && *dev.Path == s {
return []*mnt.Preset{{
Name: s,
Type: "blk",
Path: *dev.Path,
}}, nil
}
if dev.Path != nil && strings.HasSuffix(*dev.Path, "/"+s) {
partialMatch = append(partialMatch, dev)
}
}
var ret []*mnt.Preset
for _, dev := range partialMatch {
ret = append(ret, &mnt.Preset{
Name: s,
Type: "blk",
Path: *dev.Path,
})
}
return ret, nil
}
func init() {
mnt.RegisterMounter("blk", NewMounterFromPreset)
mnt.RegisterMatcher(match)
}

51
cfg/cfg.go Normal file
View File

@ -0,0 +1,51 @@
package cfg
import (
"os"
"path/filepath"
"github.com/go-errors/errors"
"github.com/ilyakaznacheev/cleanenv"
"gensokyo.cafe/xmnt/util"
)
type CfgDef struct {
CredentialStore []string `yaml:"credential_store" env-default:"$HOME/.vault2/data_encryption"`
}
func (c *CfgDef) expand() {
for i, path := range c.CredentialStore {
c.CredentialStore[i] = os.ExpandEnv(path)
}
}
var Cfg *CfgDef
func LoadAuto() error {
loc, err := os.UserConfigDir()
if err != nil {
return errors.WrapPrefix(err, "cannot obtain user config dir", 0)
}
path := filepath.Join(loc, "xmnt", "xmnt.yml")
loadFromFile, err := util.FileExists(path)
if err != nil {
return errors.WrapPrefix(err, "cannot read config file", 0)
}
cfg := &CfgDef{}
if loadFromFile {
if err = cleanenv.ReadConfig(path, cfg); err != nil {
return errors.WrapPrefix(err, "cannot read config file", 0)
}
} else {
if err = cleanenv.ReadEnv(cfg); err != nil {
return errors.WrapPrefix(err, "cannot read config from env", 0)
}
}
cfg.expand()
Cfg = cfg
return nil
}

27
go.mod Normal file
View File

@ -0,0 +1,27 @@
module gensokyo.cafe/xmnt
go 1.19
require (
github.com/bitfield/script v0.20.2
github.com/coreos/go-systemd/v22 v22.4.0
github.com/fatih/color v1.13.0
github.com/go-errors/errors v1.4.2
github.com/ilyakaznacheev/cleanenv v1.3.0
golang.org/x/exp v0.0.0-20221002003631-540bb7301a08
golang.org/x/text v0.3.7
gopkg.in/yaml.v3 v3.0.1
)
require (
bitbucket.org/creachadair/shell v0.0.7 // indirect
github.com/BurntSushi/toml v1.2.0 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/itchyny/gojq v0.12.7 // indirect
github.com/itchyny/timefmt-go v0.1.3 // indirect
github.com/joho/godotenv v1.4.0 // indirect
github.com/mattn/go-colorable v0.1.9 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect
)

52
go.sum Normal file
View File

@ -0,0 +1,52 @@
bitbucket.org/creachadair/shell v0.0.7 h1:Z96pB6DkSb7F3Y3BBnJeOZH2gazyMTWlvecSD4vDqfk=
bitbucket.org/creachadair/shell v0.0.7/go.mod h1:oqtXSSvSYr4624lnnabXHaBsYW6RD80caLi2b3hJk0U=
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
github.com/BurntSushi/toml v1.2.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/bitfield/script v0.20.2 h1:4DexsRtBILVMEn3EZwHbtJdDqdk43sXI8gM3F04JXgs=
github.com/bitfield/script v0.20.2/go.mod h1:l3AZPVAtKQrL03bwh7nlNTUtgrgSWurpJSbtqspYrOA=
github.com/coreos/go-systemd/v22 v22.4.0 h1:y9YHcjnjynCd/DVbg5j9L/33jQM3MxJlbj/zWskzfGU=
github.com/coreos/go-systemd/v22 v22.4.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/ilyakaznacheev/cleanenv v1.3.0 h1:RapuLclPPUbmdd5Bi5UXScwMEZA6+ZNLU5OW9itPjj0=
github.com/ilyakaznacheev/cleanenv v1.3.0/go.mod h1:i0owW+HDxeGKE0/JPREJOdSCPIyOnmh6C0xhWAkF/xA=
github.com/itchyny/gojq v0.12.7 h1:hYPTpeWfrJ1OT+2j6cvBScbhl0TkdwGM4bc66onUSOQ=
github.com/itchyny/gojq v0.12.7/go.mod h1:ZdvNHVlzPgUf8pgjnuDTmGfHA/21KoutQUJ3An/xNuw=
github.com/itchyny/timefmt-go v0.1.3 h1:7M3LGVDsqcd0VZH2U+x393obrzZisp7C0uEe921iRkU=
github.com/itchyny/timefmt-go v0.1.3/go.mod h1:0osSSCQSASBJMsIZnhAaF1C2fCBTJZXrnj37mG8/c+A=
github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg=
github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-colorable v0.1.9 h1:sqDoxXbdeALODt0DAeJCVp38ps9ZogZEAXjus69YV3U=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
golang.org/x/exp v0.0.0-20221002003631-540bb7301a08 h1:LtBIgSqNhkuC9gA3BFjGy5obHQT1lnmNsMDFSqWzQ5w=
golang.org/x/exp v0.0.0-20221002003631-540bb7301a08/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 h1:slmdOY3vp8a7KQbHkL+FLbvbkgMqmXojpFUO/jENuqQ=
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3/go.mod h1:oVgVk4OWVDi43qWBEyGhXgYxt7+ED4iYNpTngSLX2Iw=

90
main.go Normal file
View File

@ -0,0 +1,90 @@
package main
import (
"flag"
"fmt"
"os"
"gensokyo.cafe/xmnt/cfg"
"gensokyo.cafe/xmnt/mnt"
"gensokyo.cafe/xmnt/msg"
)
var (
unmount = flag.Bool("u", false, "Unmount instead of mount")
)
func main() {
if err := cfg.LoadAuto(); err != nil {
panic(err)
}
flag.Parse()
args := flag.Args()
if len(args) == 0 {
flag.Usage()
os.Exit(1)
}
var name, mountPoint string
name = args[0]
if len(args) > 1 {
mountPoint = args[1]
}
matched, err := mnt.MatchAll(name)
if err != nil {
msg.Errorf("Failed to find match for %q: %v", name, err)
os.Exit(1)
}
if len(matched) == 0 {
msg.Errorf("No match for %q", name)
os.Exit(1)
}
if len(matched) > 1 {
msg.Errorf("Ambiguous name %q", name)
msg.Infof("%d matches:", len(matched))
for _, m := range matched {
msg.Infof(" %s", m)
}
os.Exit(1)
}
preset := matched[0]
if mountPoint != "" {
preset.MountPoint = mountPoint
}
mounter, err := mnt.MounterFromPreset(preset)
if err != nil {
msg.Errorf("Failed to initialize mounter: %v", err)
os.Exit(1)
}
switch *unmount {
case false:
err = mounter.Mount()
case true:
err = mounter.Unmount()
}
if err != nil {
msg.Errorf("%v", mountPoint, err)
os.Exit(1)
}
msg.Infof("Success")
}
func init() {
usageText := fmt.Sprintf("Usage: %s [options] <name> [mountpoint]\n\nOptions:\n", os.Args[0])
argsUsageText := `
Arguments:
name Name or path of the device to mount. Can also be name of a preset
mountpoint Mount point (optional)
`
flag.Usage = func() {
_, _ = fmt.Fprint(flag.CommandLine.Output(), usageText)
flag.PrintDefaults()
_, _ = fmt.Fprintln(flag.CommandLine.Output(), argsUsageText)
}
}

68
mnt/mnt.go Normal file
View File

@ -0,0 +1,68 @@
package mnt
import (
"github.com/go-errors/errors"
"gensokyo.cafe/xmnt/msg"
)
type Mounter interface {
Mount() error
Unmount() error
}
type (
MounterFunc func(*Preset) (Mounter, error)
MatcherFunc func(string) ([]*Preset, error)
)
var mounters = map[string]MounterFunc{}
func RegisterMounter(typeName string, fn MounterFunc) {
mounters[typeName] = fn
}
func MounterFromPreset(preset *Preset) (Mounter, error) {
if fn, ok := mounters[preset.Type]; !ok {
return nil, errors.Errorf("unknown type %q", preset.Type)
} else {
return fn(preset)
}
}
var matchers []MatcherFunc
func RegisterMatcher(fn MatcherFunc) {
matchers = append(matchers, fn)
}
func MatchAll(s string) ([]*Preset, error) {
// Match against presets. If matched, return the preset (should be only one).
if matches, err := match(s); err != nil {
msg.Errorf("Failed to find match in presets: %v", err)
} else if len(matches) > 0 {
return matches[:1], nil
}
// Run other matchers. Might produce multiple matches.
var ret []*Preset
rCh := make(chan []*Preset, 1)
for _, fn := range matchers {
go func(fn MatcherFunc) {
matches, err := fn(s)
if err != nil {
msg.Errorf("%v", err)
rCh <- nil
} else {
rCh <- matches
}
}(fn)
}
for range matchers {
matches := <-rCh
ret = append(ret, matches...)
}
return ret, nil
}

97
mnt/preset.go Normal file
View File

@ -0,0 +1,97 @@
package mnt
import (
"os"
"path/filepath"
"strings"
"github.com/go-errors/errors"
"gopkg.in/yaml.v3"
"gensokyo.cafe/xmnt/msg"
"gensokyo.cafe/xmnt/util"
)
type Preset struct {
Name string `yaml:"name"`
Type string `yaml:"type"`
Path string `yaml:"path"`
MountPoint string `yaml:"mountpoint"`
AuthCmd string `yaml:"auth_cmd"` // e.g. for loading the encryption key
MountCmd string `yaml:"mount_cmd"`
CheckCmd string `yaml:"check_cmd"` // check if the mounting was successful
UnmountCmd string `yaml:"unmount_cmd"`
UnAuthCmd string `yaml:"unauth_cmd"` // e.g. for unloading the encryption key
UnmountCheckCmd string `yaml:"unmount_check_cmd"`
}
func (p *Preset) String() string {
return p.Type + ": " + p.Name + " (" + p.Path + ")"
}
func readPreset(path string) (*Preset, error) {
buf, err := os.ReadFile(path)
if err != nil {
return nil, errors.Wrap(err, 0)
}
var preset *Preset
if err = yaml.Unmarshal(buf, &preset); err != nil {
return nil, errors.Wrap(err, 0)
}
if preset.Name == "" {
preset.Name = filepath.Base(path)[:len(filepath.Base(path))-4] // remove .yml
}
return preset, nil
}
func ReadPresets() ([]*Preset, error) {
cfgDir, err := os.UserConfigDir()
if err != nil {
return nil, errors.WrapPrefix(err, "cannot obtain user config dir", 0)
}
presetsDir := filepath.Join(cfgDir, "xmnt", "presets")
dirExist, err := util.DirExists(presetsDir)
if err != nil {
return nil, errors.Wrap(err, 0)
}
if !dirExist {
return nil, nil
}
var ret []*Preset
entries, err := os.ReadDir(presetsDir)
if err != nil {
return nil, errors.Wrap(err, 0)
}
for _, entry := range entries {
if !entry.Type().IsRegular() || !strings.HasSuffix(entry.Name(), ".yml") {
continue
}
preset, err := readPreset(filepath.Join(presetsDir, entry.Name()))
if err != nil {
msg.Errorf("Failed to read preset %s: %v", entry.Name(), err)
continue
}
ret = append(ret, preset)
}
return ret, nil
}
func match(s string) ([]*Preset, error) {
presets, err := ReadPresets()
if err != nil {
return nil, errors.Wrap(err, 0)
}
for _, p := range presets {
// for presets, only return exact match.
if p.Name == s {
return []*Preset{p}, nil
}
}
return nil, nil
}

7
mounters.go Normal file
View File

@ -0,0 +1,7 @@
package main
import (
_ "gensokyo.cafe/xmnt/blk"
_ "gensokyo.cafe/xmnt/zfs"
)

22
msg/msg.go Normal file
View File

@ -0,0 +1,22 @@
package msg
import (
"os"
"github.com/fatih/color"
)
var (
infoStyle = color.New(color.FgBlue)
errStyle = color.New(color.FgRed)
)
func Infof(format string, a ...any) {
_, _ = infoStyle.Fprintf(os.Stderr, format, a...)
_, _ = os.Stderr.WriteString("\n")
}
func Errorf(format string, a ...any) {
_, _ = errStyle.Fprintf(os.Stderr, format, a...)
_, _ = os.Stderr.WriteString("\n")
}

39
util/command.go Normal file
View File

@ -0,0 +1,39 @@
package util
import (
"fmt"
"io"
"os"
"os/exec"
"github.com/go-errors/errors"
"gensokyo.cafe/xmnt/msg"
)
func RunCommand(cmd string, input io.Reader, args ...string) ([]byte, error) {
command := exec.Command(cmd, args...)
if input != nil {
command.Stdin = input
} else {
command.Stdin = os.Stdin
}
command.Stderr = os.Stderr
byt, err := command.Output()
if err != nil {
if e, ok := err.(*exec.ExitError); ok {
return nil, errors.New(
fmt.Sprintf("command %q %q failed with exit status %d", cmd, args, e.ExitCode()),
)
}
return nil, errors.New(err)
}
return byt, nil
}
func RunPrivilegedCommand(cmd string, input io.Reader, args ...string) ([]byte, error) {
realArgs := append([]string{cmd}, args...)
msg.Infof("Running command with sudo: %q", realArgs)
return RunCommand("sudo", input, realArgs...)
}

16
util/mount.go Normal file
View File

@ -0,0 +1,16 @@
package util
import (
"path/filepath"
"regexp"
)
var mountPointPattern = regexp.MustCompile(`^/[\w/.-]*$`)
func IsValidMountPoint(mp string) bool {
if !mountPointPattern.MatchString(mp) {
return false
}
return filepath.Clean(mp) == mp
}

124
util/systemd.go Normal file
View File

@ -0,0 +1,124 @@
package util
import (
"context"
"fmt"
"time"
"github.com/coreos/go-systemd/v22/dbus"
"github.com/coreos/go-systemd/v22/unit"
"github.com/go-errors/errors"
"gensokyo.cafe/xmnt/msg"
)
type SystemdConnection struct {
conn *dbus.Conn
}
func NewSystemdConnection() (*SystemdConnection, error) {
conn, err := dbus.NewSystemConnectionContext(context.Background())
if err != nil {
return nil, errors.WrapPrefix(err, "failed to establish dbus connection with systemd", 0)
}
return &SystemdConnection{conn: conn}, nil
}
func (c *SystemdConnection) Close() {
c.conn.Close()
}
func (c *SystemdConnection) startOrStopUnit(name string, isStart bool) (err error) {
rCh := make(chan string, 1)
defer close(rCh)
var (
actionText = []string{"start", "starting"}
actionFn = c.conn.StartUnitContext
)
if !isStart {
actionText = []string{"stop", "stopping"}
actionFn = c.conn.StopUnitContext
}
msg.Infof("%s systemd unit %q", Title(actionText[1]), name)
if _, err := actionFn(context.Background(), name, "replace", rCh); err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("failed to %s systemd unit %q", actionText[0], name), 0)
}
timeout := time.After(5 * time.Second)
select {
case result := <-rCh:
if result != "done" {
return errors.Errorf("failed to %s systemd unit %q: result is %q", actionText[0], name, result)
}
case <-timeout:
// FIXME: The go-systemd library does not write anything to the channel when the job is
// failed due to dependency not met. For now we use a timeout to deal with this case.
// Might cause panic because we also close the channel.
return errors.Errorf(
"failed to %s systemd unit %q: timed out waiting for response from systemd",
actionText[0], name,
)
}
return nil
}
func (c *SystemdConnection) StartUnit(name string) error {
return c.startOrStopUnit(name, true)
}
func (c *SystemdConnection) StopUnit(name string) error {
return c.startOrStopUnit(name, false)
}
func (c *SystemdConnection) FindMount(mountPoint string) (string, error) {
unitName := unit.UnitNamePathEscape(mountPoint) + ".mount"
units, err := c.conn.ListUnitsByNamesContext(context.Background(), []string{unitName})
if err != nil {
return "", errors.WrapPrefix(err, "cannot retrieve list of systemd units", 0)
}
for _, u := range units {
if u.LoadState == "loaded" {
return u.Name, nil
}
}
return "", nil
}
var (
ErrSdNotAvailable = fmt.Errorf("systemd is not available")
ErrSdMountNotFound = fmt.Errorf("systemd mount not found")
)
func systemdMountOrUnmount(mountPoint string, isMount bool) error {
sd, err := NewSystemdConnection()
if err != nil {
return errors.WrapPrefix(ErrSdNotAvailable, err.Error(), 0)
}
defer sd.Close()
sdUnit, err := sd.FindMount(mountPoint)
if err != nil {
return errors.WrapPrefix(err, "failed to find systemd mount", 0)
}
if sdUnit == "" {
return errors.Wrap(ErrSdMountNotFound, 0)
}
if err := sd.startOrStopUnit(sdUnit, isMount); err != nil {
return errors.Wrap(err, 0)
}
return nil
}
func SystemdMount(mountPoint string) error {
return systemdMountOrUnmount(mountPoint, true)
}
func SystemdUnmount(mountPoint string) error {
return systemdMountOrUnmount(mountPoint, false)
}
func ShouldSkipSdMount(err error) bool {
return errors.Is(err, ErrSdNotAvailable) || errors.Is(err, ErrSdMountNotFound)
}

126
util/util.go Normal file
View File

@ -0,0 +1,126 @@
package util
import (
"fmt"
"os"
"path/filepath"
"github.com/bitfield/script"
"github.com/go-errors/errors"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gensokyo.cafe/xmnt/msg"
)
// FileExists checks if a given path exists. It returns error when the target
// exists but is not a regular file.
func FileExists(path string) (bool, error) {
fi, err := os.Stat(path)
if os.IsNotExist(err) {
return false, nil
}
if err != nil {
return false, err
}
if !fi.Mode().IsRegular() {
err = errors.Errorf("%q is not a regular file", path)
}
return true, err
}
// DirExists checks if a given path exists. It returns error when the target
// exists but is not a directory.
func DirExists(path string) (bool, error) {
fi, err := os.Stat(path)
if os.IsNotExist(err) {
return false, nil
}
if err != nil {
return false, err
}
if !fi.Mode().IsDir() {
err = errors.Errorf("%q is not a directory", path)
}
return true, err
}
func findCredentialFileFrom(dir, uuid string) (string, error) {
keyName := uuid + ".key"
keyPath := ""
if err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}
if keyPath != "" {
return filepath.SkipDir
}
if d.Type().IsRegular() && d.Name() == keyName {
keyPath = path
}
return nil
}); err != nil {
return "", errors.Wrap(err, 0)
}
return keyPath, nil
}
func ReadCredentialFile(uuid string, dirs []string) (string, error) {
var (
keyPath string
err error
keyText []byte
)
for _, dir := range dirs {
if keyPath != "" {
break
}
keyPath, err = findCredentialFileFrom(dir, uuid)
if err != nil {
return "", errors.WrapPrefix(err, "cannot read credential file", 0)
}
}
if keyPath == "" {
return "", errors.Errorf("cannot find credential file for uuid %q", uuid)
}
isPgp, err := isPgpEncrypted(keyPath)
if err != nil {
msg.Errorf("Cannot determine type of credential file %q: %v", keyPath, err)
isPgp = false
}
if isPgp {
msg.Infof("Reading PGP encrypted credential from %q", keyPath)
keyText, err = readPgpEncryptedFile(keyPath)
} else {
msg.Infof("Reading credential from %q", keyPath)
keyText, err = os.ReadFile(keyPath)
}
if err != nil {
return "", errors.WrapPrefix(err, "cannot read credential file", 0)
}
return string(keyText), nil
}
func isPgpEncrypted(path string) (bool, error) {
output, err := RunCommand("file", nil, path)
if err != nil {
return false, errors.Wrap(err, 0)
}
l, err := script.Echo(string(output)).Match("PGP").Match("encrypted").CountLines()
return l > 0, nil
}
func readPgpEncryptedFile(path string) ([]byte, error) {
output, err := RunCommand("gpg", nil, "--decrypt", path)
if err != nil {
return nil, errors.WrapPrefix(err, fmt.Sprintf("failed to read pgp encrypted file %q", path), 0)
}
return output, nil
}
func Title(s string) string {
return cases.Title(language.Und).String(s)
}

431
zfs/zfs.go Normal file
View File

@ -0,0 +1,431 @@
package zfs
import (
"bufio"
"bytes"
"fmt"
"io"
"os/user"
"regexp"
"strings"
"github.com/bitfield/script"
"github.com/go-errors/errors"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"gensokyo.cafe/xmnt/cfg"
"gensokyo.cafe/xmnt/mnt"
"gensokyo.cafe/xmnt/msg"
"gensokyo.cafe/xmnt/util"
)
const ZfsBin = "/usr/bin/zfs"
type execFunc func(string, io.Reader, ...string) ([]byte, error)
type Permissions struct {
Mount bool
LoadKey bool
}
var allPermission = Permissions{Mount: true, LoadKey: true}
type KeyStatus string
const (
KeyStatusAvailable KeyStatus = "available"
KeyStatusUnavailable KeyStatus = "unavailable"
)
type CanMount string
const (
CanMountOn CanMount = "on"
CanMountOff CanMount = "off"
CanMountNoAuto CanMount = "noauto"
)
type Dataset struct {
Name string
GUID string
CanMount CanMount
MountPoint string
KeyStatus KeyStatus
EncryptionRoot string
Mounted bool
permissions *Permissions
}
func (d *Dataset) loadPermissions() error {
if d.permissions != nil {
return nil
}
currentUser, err := user.Current()
if err != nil {
return errors.WrapPrefix(err, "cannot obtain current user", 0)
}
if currentUser.Uid == "0" {
d.permissions = &allPermission
return nil
}
permissions := &Permissions{}
output, err := util.RunCommand(ZfsBin, nil, "allow", d.Name)
if err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("cannot obtain permission list for zfs dataset %s", d.Name), 0)
}
pattern := regexp.MustCompile(`^\s*user ` + currentUser.Username + ` ([\w-]+(,[\w-]+)*)$`)
script.Echo(string(output)).FilterLine(func(line string) string {
match := pattern.FindStringSubmatch(line)
if match != nil {
for _, perm := range strings.Split(match[1], ",") {
switch perm {
case "mount":
permissions.Mount = true
case "load-key":
permissions.LoadKey = true
}
}
}
return ""
}).Wait()
d.permissions = permissions
return nil
}
func (d *Dataset) Permissions() (*Permissions, error) {
if err := d.loadPermissions(); err != nil {
return nil, err
}
return &*d.permissions, nil
}
func listCmd(name string, recursive bool) ([]byte, error) {
zfsArgs := []string{"get", "-Ht", "filesystem", "guid,canmount,mountpoint,encryptionroot,keystatus,mounted"}
if recursive {
zfsArgs = append(zfsArgs, "-r")
}
if name != "" {
zfsArgs = append(zfsArgs, name)
}
return util.RunCommand(ZfsBin, nil, zfsArgs...)
}
func ParseListOutput(output []byte) ([]*Dataset, error) {
idx := map[string]*Dataset{}
sc := bufio.NewScanner(bytes.NewReader(output))
for sc.Scan() {
fields := strings.Split(sc.Text(), "\t")
if len(fields) != 4 {
return nil, errors.Errorf("invalid zfs list output: %q", sc.Text())
}
name := fields[0]
key := fields[1]
value := fields[2]
dataset, ok := idx[name]
if !ok {
dataset = &Dataset{Name: name}
idx[name] = dataset
}
switch key {
case "guid":
dataset.GUID = value
case "canmount":
dataset.CanMount = CanMount(value)
case "mountpoint":
dataset.MountPoint = value
case "encryptionroot":
dataset.EncryptionRoot = value
case "keystatus":
dataset.KeyStatus = KeyStatus(value)
case "mounted":
dataset.Mounted = value == "yes"
}
}
return maps.Values(idx), nil
}
func List(name string, recursive bool) ([]*Dataset, error) {
listOutput, err := listCmd(name, recursive)
if err != nil {
return nil, errors.WrapPrefix(err, "cannot obtain list of zfs datasets", 0)
}
datasets, err := ParseListOutput(listOutput)
if err != nil {
return nil, errors.WrapPrefix(err, "cannot parse zfs list output", 0)
}
return datasets, nil
}
type Mounter struct {
preset *mnt.Preset
dataset *Dataset
}
func NewMounterFromPreset(p *mnt.Preset) (mnt.Mounter, error) {
preset := &*p
if preset.Path == "" {
return nil, errors.New("preset path is empty")
}
m := &Mounter{
preset: preset,
}
// find the dataset
if err := m.refresh(); err != nil {
return nil, err
}
return m, nil
}
func (m *Mounter) refresh() error {
datasets, err := List(m.preset.Path, false)
if err != nil {
return errors.WrapPrefix(err, "cannot obtain list of zfs datasets", 0)
}
var newDataset *Dataset
for _, d := range datasets {
if d.Name == m.preset.Path {
newDataset = d
}
}
if newDataset == nil {
return errors.Errorf("cannot find zfs dataset %q", m.preset.Path)
}
m.dataset = newDataset
return nil
}
func (m *Mounter) loadKey() error {
if m.dataset.KeyStatus != KeyStatusUnavailable {
return nil
}
if m.dataset.Name != m.dataset.EncryptionRoot {
return errors.Errorf("cannot load key for zfs dataset %q: not an encryption root", m.dataset.Name)
}
key, err := util.ReadCredentialFile(m.dataset.GUID, cfg.Cfg.CredentialStore)
if err != nil {
return errors.WrapPrefix(err, "cannot load zfs key", 0)
}
perm, err := m.dataset.Permissions()
if err != nil {
return errors.WrapPrefix(err, "failed to load zfs key", 0)
}
var run execFunc
if perm.LoadKey {
run = util.RunCommand
} else {
run = util.RunPrivilegedCommand
}
msg.Infof("zfs load-key %q", m.dataset.Name)
_, err = run(ZfsBin, strings.NewReader(key), "load-key", m.dataset.Name)
if err != nil {
return errors.WrapPrefix(err, "failed to load zfs key", 0)
}
return nil
}
func (m *Mounter) mount() error {
if m.dataset.Mounted {
return nil
}
if !slices.Contains([]CanMount{CanMountNoAuto, CanMountOn}, m.dataset.CanMount) {
return errors.Errorf("cannot mount zfs dataset %q: canmount is %q", m.dataset.Name, m.dataset.CanMount)
}
if m.dataset.KeyStatus != KeyStatusAvailable {
return errors.Errorf("cannot mount zfs dataset %q: not unlocked", m.dataset.Name)
}
mountPoint := m.preset.MountPoint
if mountPoint == "" {
mountPoint = m.dataset.MountPoint
}
if !util.IsValidMountPoint(mountPoint) {
return errors.Errorf(
"cannot mount zfs dataset %q: invalid mount point %q",
m.dataset.Name, mountPoint,
)
}
// If systemd mount unit exists, use it.
// In this case, if we do not use systemd for mounting, systemd will mess with the mounting process, and the zfs
// dataset will get unmounted immediately after mounting. See https://github.com/openzfs/zfs/issues/11248
if err := util.SystemdMount(mountPoint); err == nil {
return nil
} else if !util.ShouldSkipSdMount(err) {
return errors.WrapPrefix(err, fmt.Sprintf("failed to mount zfs dataset %q", m.dataset.Name), 0)
}
// mount using zfs command
mountArgs := []string{"mount"}
if m.preset.MountPoint != "" {
// user specified the mount point
mountArgs = append(mountArgs, "-o", "mountpoint="+mountPoint)
}
mountArgs = append(mountArgs, m.dataset.Name)
perm, err := m.dataset.Permissions()
if err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("failed to mount zfs dataset %q", m.dataset.Name), 0)
}
var run execFunc
if perm.Mount {
run = util.RunCommand
} else {
run = util.RunPrivilegedCommand
}
_, err = run(ZfsBin, nil, mountArgs...)
if err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("failed to mount zfs dataset %q", m.dataset.Name), 0)
}
return nil
}
func (m *Mounter) Mount() error {
if err := m.loadKey(); err != nil {
return err
}
if err := m.refresh(); err != nil {
return err
}
if err := m.mount(); err != nil {
return err
}
// check
if err := m.refresh(); err != nil {
return err
}
if !m.dataset.Mounted {
return errors.Errorf("zfs dataset %q is not mounted", m.dataset.Name)
}
return nil
}
func (m *Mounter) unmount() error {
if !m.dataset.Mounted {
return nil
}
// try to unmount with systemd
mp := m.dataset.MountPoint
if util.IsValidMountPoint(mp) {
if err := util.SystemdUnmount(mp); err == nil {
return nil
} else if !util.ShouldSkipSdMount(err) {
return errors.WrapPrefix(err, fmt.Sprintf("failed to unmount zfs dataset %q", m.dataset.Name), 0)
}
}
// try to unmount with zfs command.
perm, err := m.dataset.Permissions()
if err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("failed to unmount zfs dataset %q", m.dataset.Name), 0)
}
var run execFunc
if perm.Mount {
run = util.RunCommand
} else {
run = util.RunPrivilegedCommand
}
_, err = run(ZfsBin, nil, "unmount", "-u", m.dataset.Name)
if err != nil {
return errors.WrapPrefix(err, fmt.Sprintf("failed to unmount zfs dataset %q", m.dataset.Name), 0)
}
return nil
}
func (m *Mounter) unloadKey() error {
if m.dataset.KeyStatus != KeyStatusAvailable || m.dataset.Name != m.dataset.EncryptionRoot {
return nil
}
perm, err := m.dataset.Permissions()
if err != nil {
return errors.WrapPrefix(err, "failed to unload zfs key", 0)
}
var run execFunc
if perm.LoadKey {
run = util.RunCommand
} else {
run = util.RunPrivilegedCommand
}
msg.Infof("zfs unload-key %q", m.dataset.Name)
_, err = run(ZfsBin, nil, "unload-key", m.dataset.Name)
if err != nil {
return errors.WrapPrefix(err, "failed to unload zfs key", 0)
}
return nil
}
func (m *Mounter) Unmount() error {
if err := m.unmount(); err != nil {
return errors.Wrap(err, 0)
}
// check
if err := m.refresh(); err != nil {
return errors.WrapPrefix(err, "failed to check for result of unmounting", 0)
}
if m.dataset.Mounted {
return errors.Errorf("zfs dataset %q is still mounted", m.dataset.Name)
}
if err := m.unloadKey(); err != nil {
return errors.Wrap(err, 0)
}
return nil
}
func match(s string) ([]*mnt.Preset, error) {
datasets, err := List("", true)
if err != nil {
return nil, errors.Wrap(err, 0)
}
var partialMatch []*Dataset
for _, d := range datasets {
if d.Name == s {
return []*mnt.Preset{{
Name: s,
Type: "zfs",
Path: d.Name,
}}, nil
}
if strings.HasSuffix(d.Name, "/"+s) {
partialMatch = append(partialMatch, d)
}
}
var ret []*mnt.Preset
for _, d := range partialMatch {
ret = append(ret, &mnt.Preset{
Name: s,
Type: "zfs",
Path: d.Name,
})
}
return ret, nil
}
func init() {
mnt.RegisterMounter("zfs", NewMounterFromPreset)
mnt.RegisterMatcher(match)
}