diff --git a/.gitignore b/.gitignore index fa4ddb8..5336ec4 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /xmnt +/.idea diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..99cf855 --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +BIN_NAME := xmnt + +build: + CGO_ENABLED=0 go build -trimpath -ldflags '-s -w' -o "$(BIN_NAME)" + +install: build + install -vDt "${HOME}/.local/bin" "$(BIN_NAME)" + +clean: + rm -vf "$(BIN_NAME)" + +.PHONY: build install clean diff --git a/blk/blk.go b/blk/blk.go new file mode 100644 index 0000000..f00c6f7 --- /dev/null +++ b/blk/blk.go @@ -0,0 +1,318 @@ +package blk + +import ( + "bytes" + "encoding/json" + "regexp" + "strings" + + "github.com/go-errors/errors" + "golang.org/x/exp/slices" + + "gensokyo.cafe/xmnt/cfg" + "gensokyo.cafe/xmnt/mnt" + "gensokyo.cafe/xmnt/msg" + "gensokyo.cafe/xmnt/util" +) + +type FSType string + +const FSTypeLuks = FSType("crypto_LUKS") + +type DevNum string + +var supportedFsTypes = []FSType{ + "ext2", "ext3", "ext4", "xfs", "vfat", "ntfs", FSTypeLuks, +} + +type BlkDev struct { + Name *string `json:"name"` + PKName *string `json:"pkname"` // name of parent device + Path *string `json:"path"` + FSType *FSType `json:"fstype"` + MountPoint *string `json:"mountpoint"` + UUID *string `json:"uuid"` + DevNum *DevNum `json:"maj:min"` + Children []*BlkDev `json:"children"` +} + +func (b *BlkDev) IsSupportedType() bool { + if b.FSType == nil { + return false + } + return slices.Contains(supportedFsTypes, *b.FSType) +} + +func (b *BlkDev) NeedsDecryption() bool { + if b.FSType == nil || *b.FSType != FSTypeLuks { + return false + } + return len(b.Children) == 0 +} + +func (b *BlkDev) IsMounted() bool { + if b.MountPoint != nil && (*b.MountPoint)[:1] == "/" { + return true + } + for _, child := range b.Children { + if child.IsMounted() { + return true + } + } + return false +} + +func flattenBlkDevs(devs []*BlkDev) []*BlkDev { + var ret []*BlkDev + for _, dev := range devs { + ret = append(ret, dev) + ret = append(ret, flattenBlkDevs(dev.Children)...) + } + return ret +} + +func List(path string) ([]*BlkDev, error) { + lsblkArgs := []string{"-JT", "-o", "name,pkname,path,fstype,mountpoint,uuid"} + if path != "" { + lsblkArgs = append(lsblkArgs, path) + } + + output, err := util.RunCommand("lsblk", nil, lsblkArgs...) + if err != nil { + return nil, errors.WrapPrefix(err, "cannot obtain list of block devices", 0) + } + resp := struct { + BlockDevices []*BlkDev `json:"blockdevices"` + }{} + if err := json.Unmarshal(output, &resp); err != nil { + return nil, errors.WrapPrefix(err, "cannot parse lsblk output", 0) + } + + return flattenBlkDevs(resp.BlockDevices), nil +} + +type Mounter struct { + dev *BlkDev + preset *mnt.Preset +} + +func NewMounterFromPreset(p *mnt.Preset) (mnt.Mounter, error) { + preset := &*p + if preset.Path == "" { + return nil, errors.New("preset path is empty") + } + if !util.IsValidMountPoint(preset.MountPoint) { + return nil, errors.Errorf("invalid mount point %q", preset.MountPoint) + } + m := &Mounter{ + preset: preset, + } + if err := m.refresh(); err != nil { + return nil, errors.Wrap(err, 0) + } + return m, nil +} + +func (m *Mounter) refresh() error { + devs, err := List(m.preset.Path) + if err != nil { + return errors.Wrap(err, 0) + } + if len(devs) < 1 { + return errors.Errorf("block device %q not found", m.preset.Path) + } + m.dev = devs[0] + return nil +} + +func escapeDmName(s string) string { + b := regexp.MustCompile(`\W`).ReplaceAllLiteral([]byte(s), []byte{'_'}) + b = regexp.MustCompile(`_+`).ReplaceAllLiteral(b, []byte{'_'}) + b = bytes.Trim(b, "_") + return string(b) +} + +func (m *Mounter) dmName() string { + for _, c := range m.dev.Children { + if strings.HasPrefix(*c.Path, "/dev/mapper/") { + return *c.Name + } + } + // TODO get from property value of systemd mount unit + if s := escapeDmName(m.preset.Name); s != "" { + return s + } + return escapeDmName(*m.dev.Path) +} + +func (m *Mounter) loadKey() error { + if !m.dev.NeedsDecryption() { + return nil + } + if m.dev.UUID == nil { + return errors.Errorf("device %q does not have a UUID", *m.dev.Path) + } + dmName := m.dmName() + if dmName == "" { + return errors.Errorf("cannot determine device mapper name for device %q", *m.dev.Path) + } + + cred, err := util.ReadCredentialFile(*m.dev.UUID, cfg.Cfg.CredentialStore) + if err != nil { + return errors.Wrap(err, 0) + } + msg.Infof("Opening device %q as %q", *m.dev.Path, dmName) + if _, err := util.RunPrivilegedCommand( + "cryptsetup", strings.NewReader(cred), + "luksOpen", *m.dev.Path, dmName, + ); err != nil { + return errors.WrapPrefix(err, "cannot open luks device", 0) + } + + return nil +} + +func (m *Mounter) mount() error { + if m.dev.IsMounted() { + return nil + } + if !m.dev.IsSupportedType() || m.dev.NeedsDecryption() { + return errors.Errorf("cannot mount %q: unsupported filesystem type %q", *m.dev.Path, *m.dev.FSType) + } + if len(m.dev.Children) > 1 { + return errors.Errorf("cannot mount %q: can't handle multiple child devices", *m.dev.Path) + } + + dev := m.dev + if len(dev.Children) > 0 { + dev = dev.Children[0] + } + mp := m.preset.MountPoint + + if err := util.SystemdMount(mp); err == nil { + return nil + } else if !util.ShouldSkipSdMount(err) { + return errors.WrapPrefix(err, "cannot mount device", 0) + } + + msg.Infof("Mounting %q on %q", *dev.Path, mp) + mountOpts := []string{"-o", "noatime", *dev.Path, mp} + if _, err := util.RunPrivilegedCommand("mount", nil, mountOpts...); err != nil { + return errors.WrapPrefix(err, "failed to mount device", 0) + } + + return nil +} + +func (m *Mounter) Mount() (err error) { + if err = m.loadKey(); err != nil { + return errors.Wrap(err, 0) + } + if err = m.refresh(); err != nil { + return errors.Wrap(err, 0) + } + if err = m.mount(); err != nil { + return errors.Wrap(err, 0) + } + // TODO check + + return nil +} + +func (m *Mounter) unmount() error { + if !m.dev.IsMounted() { + return nil + } + if len(m.dev.Children) > 1 { + return errors.Errorf("cannot unmount %q: can't handle multiple child devices", *m.dev.Path) + } + + dev := m.dev + if len(dev.Children) > 0 { + dev = dev.Children[0] + } + mp := *m.dev.MountPoint + + if err := util.SystemdUnmount(mp); err == nil { + return nil + } else if !util.ShouldSkipSdMount(err) { + return errors.WrapPrefix(err, "cannot unmount device", 0) + } + + msg.Infof("Unmounting %q on %q", *dev.Path, mp) + if _, err := util.RunPrivilegedCommand("umount", nil, mp); err != nil { + return errors.WrapPrefix(err, "failed to unmount device", 0) + } + + return nil +} + +func (m *Mounter) unloadKey() error { + if *m.dev.FSType != FSTypeLuks { + return nil + } + if m.dev.NeedsDecryption() { + return nil + } + + dmName := m.dmName() + if dmName == "" { + return errors.New("cannot determine device mapper name") + } + if _, err := util.RunPrivilegedCommand("cryptsetup", nil, "close", dmName); err != nil { + return errors.WrapPrefix(err, "cannot close luks device", 0) + } + return nil +} + +func (m *Mounter) Unmount() (err error) { + if err = m.unmount(); err != nil { + return errors.Wrap(err, 0) + } + if err = m.unloadKey(); err != nil { + return errors.Wrap(err, 0) + } + // TODO check + + return nil +} + +func match(s string) ([]*mnt.Preset, error) { + if s == "" { + return nil, nil + } + devices, err := List("") + if err != nil { + return nil, errors.Wrap(err, 0) + } + + var partialMatch []*BlkDev + for _, dev := range devices { + if dev.Name != nil && *dev.Name == s || + dev.Path != nil && *dev.Path == s { + return []*mnt.Preset{{ + Name: s, + Type: "blk", + Path: *dev.Path, + }}, nil + } + if dev.Path != nil && strings.HasSuffix(*dev.Path, "/"+s) { + partialMatch = append(partialMatch, dev) + } + } + + var ret []*mnt.Preset + for _, dev := range partialMatch { + ret = append(ret, &mnt.Preset{ + Name: s, + Type: "blk", + Path: *dev.Path, + }) + } + return ret, nil +} + +func init() { + mnt.RegisterMounter("blk", NewMounterFromPreset) + mnt.RegisterMatcher(match) +} diff --git a/cfg/cfg.go b/cfg/cfg.go new file mode 100644 index 0000000..81d9c7a --- /dev/null +++ b/cfg/cfg.go @@ -0,0 +1,51 @@ +package cfg + +import ( + "os" + "path/filepath" + + "github.com/go-errors/errors" + "github.com/ilyakaznacheev/cleanenv" + + "gensokyo.cafe/xmnt/util" +) + +type CfgDef struct { + CredentialStore []string `yaml:"credential_store" env-default:"$HOME/.vault2/data_encryption"` +} + +func (c *CfgDef) expand() { + for i, path := range c.CredentialStore { + c.CredentialStore[i] = os.ExpandEnv(path) + } +} + +var Cfg *CfgDef + +func LoadAuto() error { + loc, err := os.UserConfigDir() + if err != nil { + return errors.WrapPrefix(err, "cannot obtain user config dir", 0) + } + path := filepath.Join(loc, "xmnt", "xmnt.yml") + + loadFromFile, err := util.FileExists(path) + if err != nil { + return errors.WrapPrefix(err, "cannot read config file", 0) + } + + cfg := &CfgDef{} + if loadFromFile { + if err = cleanenv.ReadConfig(path, cfg); err != nil { + return errors.WrapPrefix(err, "cannot read config file", 0) + } + } else { + if err = cleanenv.ReadEnv(cfg); err != nil { + return errors.WrapPrefix(err, "cannot read config from env", 0) + } + } + + cfg.expand() + Cfg = cfg + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d8eaee2 --- /dev/null +++ b/go.mod @@ -0,0 +1,27 @@ +module gensokyo.cafe/xmnt + +go 1.19 + +require ( + github.com/bitfield/script v0.20.2 + github.com/coreos/go-systemd/v22 v22.4.0 + github.com/fatih/color v1.13.0 + github.com/go-errors/errors v1.4.2 + github.com/ilyakaznacheev/cleanenv v1.3.0 + golang.org/x/exp v0.0.0-20221002003631-540bb7301a08 + golang.org/x/text v0.3.7 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + bitbucket.org/creachadair/shell v0.0.7 // indirect + github.com/BurntSushi/toml v1.2.0 // indirect + github.com/godbus/dbus/v5 v5.0.4 // indirect + github.com/itchyny/gojq v0.12.7 // indirect + github.com/itchyny/timefmt-go v0.1.3 // indirect + github.com/joho/godotenv v1.4.0 // indirect + github.com/mattn/go-colorable v0.1.9 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect + olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..bf46776 --- /dev/null +++ b/go.sum @@ -0,0 +1,52 @@ +bitbucket.org/creachadair/shell v0.0.7 h1:Z96pB6DkSb7F3Y3BBnJeOZH2gazyMTWlvecSD4vDqfk= +bitbucket.org/creachadair/shell v0.0.7/go.mod h1:oqtXSSvSYr4624lnnabXHaBsYW6RD80caLi2b3hJk0U= +github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= +github.com/BurntSushi/toml v1.2.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/bitfield/script v0.20.2 h1:4DexsRtBILVMEn3EZwHbtJdDqdk43sXI8gM3F04JXgs= +github.com/bitfield/script v0.20.2/go.mod h1:l3AZPVAtKQrL03bwh7nlNTUtgrgSWurpJSbtqspYrOA= +github.com/coreos/go-systemd/v22 v22.4.0 h1:y9YHcjnjynCd/DVbg5j9L/33jQM3MxJlbj/zWskzfGU= +github.com/coreos/go-systemd/v22 v22.4.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= +github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/ilyakaznacheev/cleanenv v1.3.0 h1:RapuLclPPUbmdd5Bi5UXScwMEZA6+ZNLU5OW9itPjj0= +github.com/ilyakaznacheev/cleanenv v1.3.0/go.mod h1:i0owW+HDxeGKE0/JPREJOdSCPIyOnmh6C0xhWAkF/xA= +github.com/itchyny/gojq v0.12.7 h1:hYPTpeWfrJ1OT+2j6cvBScbhl0TkdwGM4bc66onUSOQ= +github.com/itchyny/gojq v0.12.7/go.mod h1:ZdvNHVlzPgUf8pgjnuDTmGfHA/21KoutQUJ3An/xNuw= +github.com/itchyny/timefmt-go v0.1.3 h1:7M3LGVDsqcd0VZH2U+x393obrzZisp7C0uEe921iRkU= +github.com/itchyny/timefmt-go v0.1.3/go.mod h1:0osSSCQSASBJMsIZnhAaF1C2fCBTJZXrnj37mG8/c+A= +github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= +github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/mattn/go-colorable v0.1.9 h1:sqDoxXbdeALODt0DAeJCVp38ps9ZogZEAXjus69YV3U= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +golang.org/x/exp v0.0.0-20221002003631-540bb7301a08 h1:LtBIgSqNhkuC9gA3BFjGy5obHQT1lnmNsMDFSqWzQ5w= +golang.org/x/exp v0.0.0-20221002003631-540bb7301a08/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 h1:slmdOY3vp8a7KQbHkL+FLbvbkgMqmXojpFUO/jENuqQ= +olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3/go.mod h1:oVgVk4OWVDi43qWBEyGhXgYxt7+ED4iYNpTngSLX2Iw= diff --git a/main.go b/main.go new file mode 100644 index 0000000..8124642 --- /dev/null +++ b/main.go @@ -0,0 +1,90 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "gensokyo.cafe/xmnt/cfg" + "gensokyo.cafe/xmnt/mnt" + "gensokyo.cafe/xmnt/msg" +) + +var ( + unmount = flag.Bool("u", false, "Unmount instead of mount") +) + +func main() { + if err := cfg.LoadAuto(); err != nil { + panic(err) + } + + flag.Parse() + args := flag.Args() + + if len(args) == 0 { + flag.Usage() + os.Exit(1) + } + + var name, mountPoint string + name = args[0] + if len(args) > 1 { + mountPoint = args[1] + } + + matched, err := mnt.MatchAll(name) + if err != nil { + msg.Errorf("Failed to find match for %q: %v", name, err) + os.Exit(1) + } + if len(matched) == 0 { + msg.Errorf("No match for %q", name) + os.Exit(1) + } + + if len(matched) > 1 { + msg.Errorf("Ambiguous name %q", name) + msg.Infof("%d matches:", len(matched)) + for _, m := range matched { + msg.Infof(" %s", m) + } + os.Exit(1) + } + + preset := matched[0] + if mountPoint != "" { + preset.MountPoint = mountPoint + } + mounter, err := mnt.MounterFromPreset(preset) + if err != nil { + msg.Errorf("Failed to initialize mounter: %v", err) + os.Exit(1) + } + switch *unmount { + case false: + err = mounter.Mount() + case true: + err = mounter.Unmount() + } + if err != nil { + msg.Errorf("%v", mountPoint, err) + os.Exit(1) + } + msg.Infof("Success") +} + +func init() { + usageText := fmt.Sprintf("Usage: %s [options] [mountpoint]\n\nOptions:\n", os.Args[0]) + argsUsageText := ` +Arguments: + name Name or path of the device to mount. Can also be name of a preset + mountpoint Mount point (optional) +` + + flag.Usage = func() { + _, _ = fmt.Fprint(flag.CommandLine.Output(), usageText) + flag.PrintDefaults() + _, _ = fmt.Fprintln(flag.CommandLine.Output(), argsUsageText) + } +} diff --git a/mnt/mnt.go b/mnt/mnt.go new file mode 100644 index 0000000..c6daede --- /dev/null +++ b/mnt/mnt.go @@ -0,0 +1,68 @@ +package mnt + +import ( + "github.com/go-errors/errors" + + "gensokyo.cafe/xmnt/msg" +) + +type Mounter interface { + Mount() error + Unmount() error +} + +type ( + MounterFunc func(*Preset) (Mounter, error) + MatcherFunc func(string) ([]*Preset, error) +) + +var mounters = map[string]MounterFunc{} + +func RegisterMounter(typeName string, fn MounterFunc) { + mounters[typeName] = fn +} + +func MounterFromPreset(preset *Preset) (Mounter, error) { + if fn, ok := mounters[preset.Type]; !ok { + return nil, errors.Errorf("unknown type %q", preset.Type) + } else { + return fn(preset) + } +} + +var matchers []MatcherFunc + +func RegisterMatcher(fn MatcherFunc) { + matchers = append(matchers, fn) +} + +func MatchAll(s string) ([]*Preset, error) { + // Match against presets. If matched, return the preset (should be only one). + if matches, err := match(s); err != nil { + msg.Errorf("Failed to find match in presets: %v", err) + } else if len(matches) > 0 { + return matches[:1], nil + } + + // Run other matchers. Might produce multiple matches. + var ret []*Preset + rCh := make(chan []*Preset, 1) + + for _, fn := range matchers { + go func(fn MatcherFunc) { + matches, err := fn(s) + if err != nil { + msg.Errorf("%v", err) + rCh <- nil + } else { + rCh <- matches + } + }(fn) + } + for range matchers { + matches := <-rCh + ret = append(ret, matches...) + } + + return ret, nil +} diff --git a/mnt/preset.go b/mnt/preset.go new file mode 100644 index 0000000..8d5061d --- /dev/null +++ b/mnt/preset.go @@ -0,0 +1,97 @@ +package mnt + +import ( + "os" + "path/filepath" + "strings" + + "github.com/go-errors/errors" + "gopkg.in/yaml.v3" + + "gensokyo.cafe/xmnt/msg" + "gensokyo.cafe/xmnt/util" +) + +type Preset struct { + Name string `yaml:"name"` + Type string `yaml:"type"` + Path string `yaml:"path"` + MountPoint string `yaml:"mountpoint"` + + AuthCmd string `yaml:"auth_cmd"` // e.g. for loading the encryption key + MountCmd string `yaml:"mount_cmd"` + CheckCmd string `yaml:"check_cmd"` // check if the mounting was successful + + UnmountCmd string `yaml:"unmount_cmd"` + UnAuthCmd string `yaml:"unauth_cmd"` // e.g. for unloading the encryption key + UnmountCheckCmd string `yaml:"unmount_check_cmd"` +} + +func (p *Preset) String() string { + return p.Type + ": " + p.Name + " (" + p.Path + ")" +} + +func readPreset(path string) (*Preset, error) { + buf, err := os.ReadFile(path) + if err != nil { + return nil, errors.Wrap(err, 0) + } + var preset *Preset + if err = yaml.Unmarshal(buf, &preset); err != nil { + return nil, errors.Wrap(err, 0) + } + + if preset.Name == "" { + preset.Name = filepath.Base(path)[:len(filepath.Base(path))-4] // remove .yml + } + return preset, nil +} + +func ReadPresets() ([]*Preset, error) { + cfgDir, err := os.UserConfigDir() + if err != nil { + return nil, errors.WrapPrefix(err, "cannot obtain user config dir", 0) + } + + presetsDir := filepath.Join(cfgDir, "xmnt", "presets") + dirExist, err := util.DirExists(presetsDir) + if err != nil { + return nil, errors.Wrap(err, 0) + } + if !dirExist { + return nil, nil + } + + var ret []*Preset + + entries, err := os.ReadDir(presetsDir) + if err != nil { + return nil, errors.Wrap(err, 0) + } + for _, entry := range entries { + if !entry.Type().IsRegular() || !strings.HasSuffix(entry.Name(), ".yml") { + continue + } + preset, err := readPreset(filepath.Join(presetsDir, entry.Name())) + if err != nil { + msg.Errorf("Failed to read preset %s: %v", entry.Name(), err) + continue + } + ret = append(ret, preset) + } + return ret, nil +} + +func match(s string) ([]*Preset, error) { + presets, err := ReadPresets() + if err != nil { + return nil, errors.Wrap(err, 0) + } + for _, p := range presets { + // for presets, only return exact match. + if p.Name == s { + return []*Preset{p}, nil + } + } + return nil, nil +} diff --git a/mounters.go b/mounters.go new file mode 100644 index 0000000..e288b82 --- /dev/null +++ b/mounters.go @@ -0,0 +1,7 @@ +package main + +import ( + _ "gensokyo.cafe/xmnt/blk" + + _ "gensokyo.cafe/xmnt/zfs" +) \ No newline at end of file diff --git a/msg/msg.go b/msg/msg.go new file mode 100644 index 0000000..145ba8a --- /dev/null +++ b/msg/msg.go @@ -0,0 +1,22 @@ +package msg + +import ( + "os" + + "github.com/fatih/color" +) + +var ( + infoStyle = color.New(color.FgBlue) + errStyle = color.New(color.FgRed) +) + +func Infof(format string, a ...any) { + _, _ = infoStyle.Fprintf(os.Stderr, format, a...) + _, _ = os.Stderr.WriteString("\n") +} + +func Errorf(format string, a ...any) { + _, _ = errStyle.Fprintf(os.Stderr, format, a...) + _, _ = os.Stderr.WriteString("\n") +} diff --git a/util/command.go b/util/command.go new file mode 100644 index 0000000..ba2a920 --- /dev/null +++ b/util/command.go @@ -0,0 +1,39 @@ +package util + +import ( + "fmt" + "io" + "os" + "os/exec" + + "github.com/go-errors/errors" + + "gensokyo.cafe/xmnt/msg" +) + +func RunCommand(cmd string, input io.Reader, args ...string) ([]byte, error) { + command := exec.Command(cmd, args...) + if input != nil { + command.Stdin = input + } else { + command.Stdin = os.Stdin + } + command.Stderr = os.Stderr + byt, err := command.Output() + + if err != nil { + if e, ok := err.(*exec.ExitError); ok { + return nil, errors.New( + fmt.Sprintf("command %q %q failed with exit status %d", cmd, args, e.ExitCode()), + ) + } + return nil, errors.New(err) + } + return byt, nil +} + +func RunPrivilegedCommand(cmd string, input io.Reader, args ...string) ([]byte, error) { + realArgs := append([]string{cmd}, args...) + msg.Infof("Running command with sudo: %q", realArgs) + return RunCommand("sudo", input, realArgs...) +} diff --git a/util/mount.go b/util/mount.go new file mode 100644 index 0000000..4ef8381 --- /dev/null +++ b/util/mount.go @@ -0,0 +1,16 @@ +package util + +import ( + "path/filepath" + "regexp" +) + +var mountPointPattern = regexp.MustCompile(`^/[\w/.-]*$`) + +func IsValidMountPoint(mp string) bool { + if !mountPointPattern.MatchString(mp) { + return false + } + + return filepath.Clean(mp) == mp +} diff --git a/util/systemd.go b/util/systemd.go new file mode 100644 index 0000000..ed91cd5 --- /dev/null +++ b/util/systemd.go @@ -0,0 +1,124 @@ +package util + +import ( + "context" + "fmt" + "time" + + "github.com/coreos/go-systemd/v22/dbus" + "github.com/coreos/go-systemd/v22/unit" + "github.com/go-errors/errors" + + "gensokyo.cafe/xmnt/msg" +) + +type SystemdConnection struct { + conn *dbus.Conn +} + +func NewSystemdConnection() (*SystemdConnection, error) { + conn, err := dbus.NewSystemConnectionContext(context.Background()) + if err != nil { + return nil, errors.WrapPrefix(err, "failed to establish dbus connection with systemd", 0) + } + return &SystemdConnection{conn: conn}, nil +} + +func (c *SystemdConnection) Close() { + c.conn.Close() +} + +func (c *SystemdConnection) startOrStopUnit(name string, isStart bool) (err error) { + rCh := make(chan string, 1) + defer close(rCh) + + var ( + actionText = []string{"start", "starting"} + actionFn = c.conn.StartUnitContext + ) + if !isStart { + actionText = []string{"stop", "stopping"} + actionFn = c.conn.StopUnitContext + } + msg.Infof("%s systemd unit %q", Title(actionText[1]), name) + if _, err := actionFn(context.Background(), name, "replace", rCh); err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("failed to %s systemd unit %q", actionText[0], name), 0) + } + + timeout := time.After(5 * time.Second) + select { + case result := <-rCh: + if result != "done" { + return errors.Errorf("failed to %s systemd unit %q: result is %q", actionText[0], name, result) + } + case <-timeout: + // FIXME: The go-systemd library does not write anything to the channel when the job is + // failed due to dependency not met. For now we use a timeout to deal with this case. + // Might cause panic because we also close the channel. + return errors.Errorf( + "failed to %s systemd unit %q: timed out waiting for response from systemd", + actionText[0], name, + ) + } + return nil +} + +func (c *SystemdConnection) StartUnit(name string) error { + return c.startOrStopUnit(name, true) +} + +func (c *SystemdConnection) StopUnit(name string) error { + return c.startOrStopUnit(name, false) +} + +func (c *SystemdConnection) FindMount(mountPoint string) (string, error) { + unitName := unit.UnitNamePathEscape(mountPoint) + ".mount" + + units, err := c.conn.ListUnitsByNamesContext(context.Background(), []string{unitName}) + if err != nil { + return "", errors.WrapPrefix(err, "cannot retrieve list of systemd units", 0) + } + for _, u := range units { + if u.LoadState == "loaded" { + return u.Name, nil + } + } + return "", nil +} + +var ( + ErrSdNotAvailable = fmt.Errorf("systemd is not available") + ErrSdMountNotFound = fmt.Errorf("systemd mount not found") +) + +func systemdMountOrUnmount(mountPoint string, isMount bool) error { + sd, err := NewSystemdConnection() + if err != nil { + return errors.WrapPrefix(ErrSdNotAvailable, err.Error(), 0) + } + defer sd.Close() + + sdUnit, err := sd.FindMount(mountPoint) + if err != nil { + return errors.WrapPrefix(err, "failed to find systemd mount", 0) + } + if sdUnit == "" { + return errors.Wrap(ErrSdMountNotFound, 0) + } + if err := sd.startOrStopUnit(sdUnit, isMount); err != nil { + return errors.Wrap(err, 0) + } + return nil +} + +func SystemdMount(mountPoint string) error { + return systemdMountOrUnmount(mountPoint, true) +} + +func SystemdUnmount(mountPoint string) error { + return systemdMountOrUnmount(mountPoint, false) +} + +func ShouldSkipSdMount(err error) bool { + return errors.Is(err, ErrSdNotAvailable) || errors.Is(err, ErrSdMountNotFound) +} diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..77e8761 --- /dev/null +++ b/util/util.go @@ -0,0 +1,126 @@ +package util + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/bitfield/script" + "github.com/go-errors/errors" + "golang.org/x/text/cases" + "golang.org/x/text/language" + + "gensokyo.cafe/xmnt/msg" +) + +// FileExists checks if a given path exists. It returns error when the target +// exists but is not a regular file. +func FileExists(path string) (bool, error) { + fi, err := os.Stat(path) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, err + } + if !fi.Mode().IsRegular() { + err = errors.Errorf("%q is not a regular file", path) + } + return true, err +} + +// DirExists checks if a given path exists. It returns error when the target +// exists but is not a directory. +func DirExists(path string) (bool, error) { + fi, err := os.Stat(path) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, err + } + if !fi.Mode().IsDir() { + err = errors.Errorf("%q is not a directory", path) + } + return true, err +} + +func findCredentialFileFrom(dir, uuid string) (string, error) { + keyName := uuid + ".key" + keyPath := "" + if err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if keyPath != "" { + return filepath.SkipDir + } + + if d.Type().IsRegular() && d.Name() == keyName { + keyPath = path + } + return nil + }); err != nil { + return "", errors.Wrap(err, 0) + } + return keyPath, nil +} + +func ReadCredentialFile(uuid string, dirs []string) (string, error) { + var ( + keyPath string + err error + keyText []byte + ) + + for _, dir := range dirs { + if keyPath != "" { + break + } + keyPath, err = findCredentialFileFrom(dir, uuid) + if err != nil { + return "", errors.WrapPrefix(err, "cannot read credential file", 0) + } + } + if keyPath == "" { + return "", errors.Errorf("cannot find credential file for uuid %q", uuid) + } + + isPgp, err := isPgpEncrypted(keyPath) + if err != nil { + msg.Errorf("Cannot determine type of credential file %q: %v", keyPath, err) + isPgp = false + } + if isPgp { + msg.Infof("Reading PGP encrypted credential from %q", keyPath) + keyText, err = readPgpEncryptedFile(keyPath) + } else { + msg.Infof("Reading credential from %q", keyPath) + keyText, err = os.ReadFile(keyPath) + } + if err != nil { + return "", errors.WrapPrefix(err, "cannot read credential file", 0) + } + return string(keyText), nil +} + +func isPgpEncrypted(path string) (bool, error) { + output, err := RunCommand("file", nil, path) + if err != nil { + return false, errors.Wrap(err, 0) + } + l, err := script.Echo(string(output)).Match("PGP").Match("encrypted").CountLines() + return l > 0, nil +} + +func readPgpEncryptedFile(path string) ([]byte, error) { + output, err := RunCommand("gpg", nil, "--decrypt", path) + if err != nil { + return nil, errors.WrapPrefix(err, fmt.Sprintf("failed to read pgp encrypted file %q", path), 0) + } + return output, nil +} + +func Title(s string) string { + return cases.Title(language.Und).String(s) +} diff --git a/zfs/zfs.go b/zfs/zfs.go new file mode 100644 index 0000000..29f1c82 --- /dev/null +++ b/zfs/zfs.go @@ -0,0 +1,431 @@ +package zfs + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os/user" + "regexp" + "strings" + + "github.com/bitfield/script" + "github.com/go-errors/errors" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "gensokyo.cafe/xmnt/cfg" + "gensokyo.cafe/xmnt/mnt" + "gensokyo.cafe/xmnt/msg" + "gensokyo.cafe/xmnt/util" +) + +const ZfsBin = "/usr/bin/zfs" + +type execFunc func(string, io.Reader, ...string) ([]byte, error) + +type Permissions struct { + Mount bool + LoadKey bool +} + +var allPermission = Permissions{Mount: true, LoadKey: true} + +type KeyStatus string + +const ( + KeyStatusAvailable KeyStatus = "available" + KeyStatusUnavailable KeyStatus = "unavailable" +) + +type CanMount string + +const ( + CanMountOn CanMount = "on" + CanMountOff CanMount = "off" + CanMountNoAuto CanMount = "noauto" +) + +type Dataset struct { + Name string + GUID string + CanMount CanMount + MountPoint string + KeyStatus KeyStatus + EncryptionRoot string + Mounted bool + + permissions *Permissions +} + +func (d *Dataset) loadPermissions() error { + if d.permissions != nil { + return nil + } + currentUser, err := user.Current() + if err != nil { + return errors.WrapPrefix(err, "cannot obtain current user", 0) + } + if currentUser.Uid == "0" { + d.permissions = &allPermission + return nil + } + + permissions := &Permissions{} + + output, err := util.RunCommand(ZfsBin, nil, "allow", d.Name) + if err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("cannot obtain permission list for zfs dataset %s", d.Name), 0) + } + pattern := regexp.MustCompile(`^\s*user ` + currentUser.Username + ` ([\w-]+(,[\w-]+)*)$`) + script.Echo(string(output)).FilterLine(func(line string) string { + match := pattern.FindStringSubmatch(line) + if match != nil { + for _, perm := range strings.Split(match[1], ",") { + switch perm { + case "mount": + permissions.Mount = true + case "load-key": + permissions.LoadKey = true + } + } + } + return "" + }).Wait() + d.permissions = permissions + + return nil +} + +func (d *Dataset) Permissions() (*Permissions, error) { + if err := d.loadPermissions(); err != nil { + return nil, err + } + return &*d.permissions, nil +} + +func listCmd(name string, recursive bool) ([]byte, error) { + zfsArgs := []string{"get", "-Ht", "filesystem", "guid,canmount,mountpoint,encryptionroot,keystatus,mounted"} + if recursive { + zfsArgs = append(zfsArgs, "-r") + } + if name != "" { + zfsArgs = append(zfsArgs, name) + } + + return util.RunCommand(ZfsBin, nil, zfsArgs...) +} + +func ParseListOutput(output []byte) ([]*Dataset, error) { + idx := map[string]*Dataset{} + + sc := bufio.NewScanner(bytes.NewReader(output)) + for sc.Scan() { + fields := strings.Split(sc.Text(), "\t") + if len(fields) != 4 { + return nil, errors.Errorf("invalid zfs list output: %q", sc.Text()) + } + + name := fields[0] + key := fields[1] + value := fields[2] + + dataset, ok := idx[name] + if !ok { + dataset = &Dataset{Name: name} + idx[name] = dataset + } + switch key { + case "guid": + dataset.GUID = value + case "canmount": + dataset.CanMount = CanMount(value) + case "mountpoint": + dataset.MountPoint = value + case "encryptionroot": + dataset.EncryptionRoot = value + case "keystatus": + dataset.KeyStatus = KeyStatus(value) + case "mounted": + dataset.Mounted = value == "yes" + } + } + + return maps.Values(idx), nil +} + +func List(name string, recursive bool) ([]*Dataset, error) { + listOutput, err := listCmd(name, recursive) + if err != nil { + return nil, errors.WrapPrefix(err, "cannot obtain list of zfs datasets", 0) + } + + datasets, err := ParseListOutput(listOutput) + if err != nil { + return nil, errors.WrapPrefix(err, "cannot parse zfs list output", 0) + } + return datasets, nil +} + +type Mounter struct { + preset *mnt.Preset + dataset *Dataset +} + +func NewMounterFromPreset(p *mnt.Preset) (mnt.Mounter, error) { + preset := &*p + if preset.Path == "" { + return nil, errors.New("preset path is empty") + } + m := &Mounter{ + preset: preset, + } + + // find the dataset + if err := m.refresh(); err != nil { + return nil, err + } + + return m, nil +} + +func (m *Mounter) refresh() error { + datasets, err := List(m.preset.Path, false) + if err != nil { + return errors.WrapPrefix(err, "cannot obtain list of zfs datasets", 0) + } + var newDataset *Dataset + for _, d := range datasets { + if d.Name == m.preset.Path { + newDataset = d + } + } + if newDataset == nil { + return errors.Errorf("cannot find zfs dataset %q", m.preset.Path) + } + m.dataset = newDataset + return nil +} + +func (m *Mounter) loadKey() error { + if m.dataset.KeyStatus != KeyStatusUnavailable { + return nil + } + if m.dataset.Name != m.dataset.EncryptionRoot { + return errors.Errorf("cannot load key for zfs dataset %q: not an encryption root", m.dataset.Name) + } + + key, err := util.ReadCredentialFile(m.dataset.GUID, cfg.Cfg.CredentialStore) + if err != nil { + return errors.WrapPrefix(err, "cannot load zfs key", 0) + } + + perm, err := m.dataset.Permissions() + if err != nil { + return errors.WrapPrefix(err, "failed to load zfs key", 0) + } + var run execFunc + if perm.LoadKey { + run = util.RunCommand + } else { + run = util.RunPrivilegedCommand + } + + msg.Infof("zfs load-key %q", m.dataset.Name) + _, err = run(ZfsBin, strings.NewReader(key), "load-key", m.dataset.Name) + if err != nil { + return errors.WrapPrefix(err, "failed to load zfs key", 0) + } + return nil +} + +func (m *Mounter) mount() error { + if m.dataset.Mounted { + return nil + } + if !slices.Contains([]CanMount{CanMountNoAuto, CanMountOn}, m.dataset.CanMount) { + return errors.Errorf("cannot mount zfs dataset %q: canmount is %q", m.dataset.Name, m.dataset.CanMount) + } + if m.dataset.KeyStatus != KeyStatusAvailable { + return errors.Errorf("cannot mount zfs dataset %q: not unlocked", m.dataset.Name) + } + + mountPoint := m.preset.MountPoint + if mountPoint == "" { + mountPoint = m.dataset.MountPoint + } + if !util.IsValidMountPoint(mountPoint) { + return errors.Errorf( + "cannot mount zfs dataset %q: invalid mount point %q", + m.dataset.Name, mountPoint, + ) + } + + // If systemd mount unit exists, use it. + // In this case, if we do not use systemd for mounting, systemd will mess with the mounting process, and the zfs + // dataset will get unmounted immediately after mounting. See https://github.com/openzfs/zfs/issues/11248 + if err := util.SystemdMount(mountPoint); err == nil { + return nil + } else if !util.ShouldSkipSdMount(err) { + return errors.WrapPrefix(err, fmt.Sprintf("failed to mount zfs dataset %q", m.dataset.Name), 0) + } + + // mount using zfs command + mountArgs := []string{"mount"} + if m.preset.MountPoint != "" { + // user specified the mount point + mountArgs = append(mountArgs, "-o", "mountpoint="+mountPoint) + } + mountArgs = append(mountArgs, m.dataset.Name) + + perm, err := m.dataset.Permissions() + if err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("failed to mount zfs dataset %q", m.dataset.Name), 0) + } + var run execFunc + if perm.Mount { + run = util.RunCommand + } else { + run = util.RunPrivilegedCommand + } + + _, err = run(ZfsBin, nil, mountArgs...) + if err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("failed to mount zfs dataset %q", m.dataset.Name), 0) + } + return nil +} + +func (m *Mounter) Mount() error { + if err := m.loadKey(); err != nil { + return err + } + if err := m.refresh(); err != nil { + return err + } + if err := m.mount(); err != nil { + return err + } + + // check + if err := m.refresh(); err != nil { + return err + } + if !m.dataset.Mounted { + return errors.Errorf("zfs dataset %q is not mounted", m.dataset.Name) + } + + return nil +} + +func (m *Mounter) unmount() error { + if !m.dataset.Mounted { + return nil + } + + // try to unmount with systemd + mp := m.dataset.MountPoint + if util.IsValidMountPoint(mp) { + if err := util.SystemdUnmount(mp); err == nil { + return nil + } else if !util.ShouldSkipSdMount(err) { + return errors.WrapPrefix(err, fmt.Sprintf("failed to unmount zfs dataset %q", m.dataset.Name), 0) + } + } + + // try to unmount with zfs command. + perm, err := m.dataset.Permissions() + if err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("failed to unmount zfs dataset %q", m.dataset.Name), 0) + } + var run execFunc + if perm.Mount { + run = util.RunCommand + } else { + run = util.RunPrivilegedCommand + } + + _, err = run(ZfsBin, nil, "unmount", "-u", m.dataset.Name) + if err != nil { + return errors.WrapPrefix(err, fmt.Sprintf("failed to unmount zfs dataset %q", m.dataset.Name), 0) + } + return nil +} + +func (m *Mounter) unloadKey() error { + if m.dataset.KeyStatus != KeyStatusAvailable || m.dataset.Name != m.dataset.EncryptionRoot { + return nil + } + + perm, err := m.dataset.Permissions() + if err != nil { + return errors.WrapPrefix(err, "failed to unload zfs key", 0) + } + var run execFunc + if perm.LoadKey { + run = util.RunCommand + } else { + run = util.RunPrivilegedCommand + } + + msg.Infof("zfs unload-key %q", m.dataset.Name) + _, err = run(ZfsBin, nil, "unload-key", m.dataset.Name) + if err != nil { + return errors.WrapPrefix(err, "failed to unload zfs key", 0) + } + return nil +} + +func (m *Mounter) Unmount() error { + if err := m.unmount(); err != nil { + return errors.Wrap(err, 0) + } + + // check + if err := m.refresh(); err != nil { + return errors.WrapPrefix(err, "failed to check for result of unmounting", 0) + } + if m.dataset.Mounted { + return errors.Errorf("zfs dataset %q is still mounted", m.dataset.Name) + } + + if err := m.unloadKey(); err != nil { + return errors.Wrap(err, 0) + } + return nil +} + +func match(s string) ([]*mnt.Preset, error) { + datasets, err := List("", true) + if err != nil { + return nil, errors.Wrap(err, 0) + } + + var partialMatch []*Dataset + for _, d := range datasets { + if d.Name == s { + return []*mnt.Preset{{ + Name: s, + Type: "zfs", + Path: d.Name, + }}, nil + } + if strings.HasSuffix(d.Name, "/"+s) { + partialMatch = append(partialMatch, d) + } + } + var ret []*mnt.Preset + for _, d := range partialMatch { + ret = append(ret, &mnt.Preset{ + Name: s, + Type: "zfs", + Path: d.Name, + }) + } + return ret, nil +} + +func init() { + mnt.RegisterMounter("zfs", NewMounterFromPreset) + mnt.RegisterMatcher(match) +}