package blk import ( "bytes" "encoding/json" "regexp" "strings" "github.com/go-errors/errors" "golang.org/x/exp/slices" "gensokyo.cafe/xmnt/cfg" "gensokyo.cafe/xmnt/mnt" "gensokyo.cafe/xmnt/msg" "gensokyo.cafe/xmnt/util" ) type FSType string const FSTypeLuks = FSType("crypto_LUKS") type DevNum string var supportedFsTypes = []FSType{ "ext2", "ext3", "ext4", "xfs", "vfat", "ntfs", FSTypeLuks, } type BlkDev struct { Name *string `json:"name"` PKName *string `json:"pkname"` // name of parent device Path *string `json:"path"` FSType *FSType `json:"fstype"` MountPoint *string `json:"mountpoint"` UUID *string `json:"uuid"` DevNum *DevNum `json:"maj:min"` Children []*BlkDev `json:"children"` } func (b *BlkDev) IsSupportedType() bool { if b.FSType == nil { return false } return slices.Contains(supportedFsTypes, *b.FSType) } func (b *BlkDev) NeedsDecryption() bool { if b.FSType == nil || *b.FSType != FSTypeLuks { return false } return len(b.Children) == 0 } func (b *BlkDev) IsMounted() bool { if b.MountPoint != nil && (*b.MountPoint)[:1] == "/" { return true } for _, child := range b.Children { if child.IsMounted() { return true } } return false } func flattenBlkDevs(devs []*BlkDev) []*BlkDev { var ret []*BlkDev for _, dev := range devs { ret = append(ret, dev) ret = append(ret, flattenBlkDevs(dev.Children)...) } return ret } func List(path string) ([]*BlkDev, error) { lsblkArgs := []string{"-JT", "-o", "name,pkname,path,fstype,mountpoint,uuid"} if path != "" { lsblkArgs = append(lsblkArgs, path) } output, err := util.RunCommand("lsblk", nil, lsblkArgs...) if err != nil { return nil, errors.WrapPrefix(err, "cannot obtain list of block devices", 0) } resp := struct { BlockDevices []*BlkDev `json:"blockdevices"` }{} if err := json.Unmarshal(output, &resp); err != nil { return nil, errors.WrapPrefix(err, "cannot parse lsblk output", 0) } return flattenBlkDevs(resp.BlockDevices), nil } type Mounter struct { dev *BlkDev preset *mnt.Preset } func NewMounterFromPreset(p *mnt.Preset) (mnt.Mounter, error) { preset := *p if preset.Path == "" { return nil, errors.New("preset path is empty") } if !util.IsValidMountPoint(preset.MountPoint) { return nil, errors.Errorf("invalid mount point %q", preset.MountPoint) } m := &Mounter{ preset: &preset, } if err := m.refresh(); err != nil { return nil, errors.Wrap(err, 0) } return m, nil } func (m *Mounter) refresh() error { devs, err := List(m.preset.Path) if err != nil { return errors.Wrap(err, 0) } if len(devs) < 1 { return errors.Errorf("block device %q not found", m.preset.Path) } m.dev = devs[0] return nil } func escapeDmName(s string) string { b := regexp.MustCompile(`\W`).ReplaceAllLiteral([]byte(s), []byte{'_'}) b = regexp.MustCompile(`_+`).ReplaceAllLiteral(b, []byte{'_'}) b = bytes.Trim(b, "_") return string(b) } func (m *Mounter) dmName() string { for _, c := range m.dev.Children { if strings.HasPrefix(*c.Path, "/dev/mapper/") { return *c.Name } } // TODO get from property value of systemd mount unit if s := escapeDmName(m.preset.Name); s != "" { return s } return escapeDmName(*m.dev.Path) } func (m *Mounter) loadKey() error { if !m.dev.NeedsDecryption() { return nil } if m.dev.UUID == nil { return errors.Errorf("device %q does not have a UUID", *m.dev.Path) } dmName := m.dmName() if dmName == "" { return errors.Errorf("cannot determine device mapper name for device %q", *m.dev.Path) } cred, err := util.ReadCredentialFile(*m.dev.UUID, cfg.Cfg.CredentialStore) if err != nil { return errors.Wrap(err, 0) } msg.Infof("Opening device %q as %q", *m.dev.Path, dmName) if _, err := util.RunPrivilegedCommand( "cryptsetup", strings.NewReader(cred), "luksOpen", *m.dev.Path, dmName, ); err != nil { return errors.WrapPrefix(err, "cannot open luks device", 0) } return nil } func (m *Mounter) mount() error { if m.dev.IsMounted() { return nil } if !m.dev.IsSupportedType() || m.dev.NeedsDecryption() { return errors.Errorf("cannot mount %q: unsupported filesystem type %q", *m.dev.Path, *m.dev.FSType) } if len(m.dev.Children) > 1 { return errors.Errorf("cannot mount %q: can't handle multiple child devices", *m.dev.Path) } dev := m.dev if len(dev.Children) > 0 { dev = dev.Children[0] } mp := m.preset.MountPoint if err := util.SystemdMount(mp); err == nil { return nil } else if !util.ShouldSkipSdMount(err) { return errors.WrapPrefix(err, "cannot mount device", 0) } msg.Infof("Mounting %q on %q", *dev.Path, mp) mountOpts := []string{"-o", "noatime", *dev.Path, mp} if _, err := util.RunPrivilegedCommand("mount", nil, mountOpts...); err != nil { return errors.WrapPrefix(err, "failed to mount device", 0) } return nil } func (m *Mounter) Mount() (err error) { if err = m.loadKey(); err != nil { return errors.Wrap(err, 0) } if err = m.refresh(); err != nil { return errors.Wrap(err, 0) } if err = m.mount(); err != nil { return errors.Wrap(err, 0) } // TODO check return nil } func (m *Mounter) unmount() error { if !m.dev.IsMounted() { return nil } if len(m.dev.Children) > 1 { return errors.Errorf("cannot unmount %q: can't handle multiple child devices", *m.dev.Path) } dev := m.dev if len(dev.Children) > 0 { dev = dev.Children[0] } mp := *dev.MountPoint if err := util.SystemdUnmount(mp); err == nil { return nil } else if !util.ShouldSkipSdMount(err) { return errors.WrapPrefix(err, "cannot unmount device", 0) } msg.Infof("Unmounting %q on %q", *dev.Path, mp) if _, err := util.RunPrivilegedCommand("umount", nil, mp); err != nil { return errors.WrapPrefix(err, "failed to unmount device", 0) } return nil } func (m *Mounter) unloadKey() error { if *m.dev.FSType != FSTypeLuks { return nil } if m.dev.NeedsDecryption() { return nil } dmName := m.dmName() if dmName == "" { return errors.New("cannot determine device mapper name") } if _, err := util.RunPrivilegedCommand("cryptsetup", nil, "close", dmName); err != nil { return errors.WrapPrefix(err, "cannot close luks device", 0) } return nil } func (m *Mounter) Unmount() (err error) { if err = m.unmount(); err != nil { return errors.Wrap(err, 0) } if err = m.unloadKey(); err != nil { return errors.Wrap(err, 0) } // TODO check return nil } func match(s string) ([]*mnt.Preset, error) { if s == "" { return nil, nil } devices, err := List("") if err != nil { return nil, errors.Wrap(err, 0) } var partialMatch []*BlkDev for _, dev := range devices { if dev.Name != nil && *dev.Name == s || dev.Path != nil && *dev.Path == s { return []*mnt.Preset{{ Name: s, Type: "blk", Path: *dev.Path, }}, nil } if dev.Path != nil && strings.HasSuffix(*dev.Path, "/"+s) { partialMatch = append(partialMatch, dev) } } var ret []*mnt.Preset for _, dev := range partialMatch { ret = append(ret, &mnt.Preset{ Name: s, Type: "blk", Path: *dev.Path, }) } return ret, nil } func init() { mnt.RegisterMounter("blk", NewMounterFromPreset) mnt.RegisterMatcher(match) }