162 lines
4.5 KiB
Go
162 lines
4.5 KiB
Go
package etchosts
|
|
|
|
import (
	"bufio"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path"
	"sort"
	"strings"

	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"
)
|
|
|
|
// Defaults applied when the corresponding EtcHosts field is left empty.
const (
	// DefaultBanner is the default magic comment used to identify entries managed by etchosts.
	DefaultBanner = "# ! MANAGED AUTOMATICALLY !"

	// DefaultPath is the default path used to write hosts entries.
	DefaultPath = "/etc/hosts"
)
|
|
|
|
// EtcHosts contains the options used to write hosts entries.
|
|
// The zero value can be used to write to DefaultPath using DefaultBanner as a marker.
|
|
type EtcHosts struct {
|
|
// Banner is the magic comment used to identify entries managed by etchosts; if not set, will use DefaultBanner.
|
|
// It must start with "#" to mark it as a comment.
|
|
Banner string
|
|
// Path is the path to the /etc/hosts file; if not set, will use DefaultPath.
|
|
Path string
|
|
// Logger is an optional logrus.StdLogger interface, used for debugging.
|
|
Logger log.StdLogger
|
|
}
|
|
|
|
// WriteEntries rewrites the hosts file at EtcHosts.Path (DefaultPath when empty)
// so that its managed lines reflect ipsToNames.
//
// Each IP address is written on its own line together with its (possibly
// multiple) hostnames, and the line is marked with EtcHosts.Banner. This lets
// later runs find their own entries without overwriting preexisting ones. The
// new content is first written to a temporary file in the same directory and
// then moved over the original. The hosts file is never created; if it is
// missing, an error is returned.
func (eh *EtcHosts) WriteEntries(ipsToNames map[string][]string) error {
	target := DefaultPath
	if eh.Path != "" {
		target = eh.Path
	}

	// No O_CREATE on purpose: if the hosts file is not there, the path is most
	// likely wrong and creating it would hide that.
	current, err := os.OpenFile(target, os.O_RDWR, 0644)
	if err != nil {
		return errors.Wrapf(err, "could not open %s for reading", target)
	}
	defer current.Close()

	// Put the temporary file next to the hosts file so it can be renamed over it.
	// NOTE(review): path.Dir only understands '/' separators; filepath.Dir would
	// be needed for Windows-style paths.
	scratch, err := ioutil.TempFile(path.Dir(target), "etchosts")
	if err != nil {
		return errors.Wrap(err, "could not create tempfile")
	}

	// Clean up the temporary file. Removal is expected to report "not exist"
	// once the file has been renamed over the hosts file, so that case is not
	// an error.
	defer func() {
		scratch.Close()
		rmErr := os.Remove(scratch.Name())
		if rmErr == nil || os.IsNotExist(rmErr) {
			return
		}
		if eh.Logger != nil {
			eh.Logger.Printf("unexpected error trying to remove temp file %s: %s", scratch.Name(), rmErr)
		}
	}()

	if err = eh.writeEntries(current, scratch, ipsToNames); err != nil {
		return err
	}
	return eh.movePreservePerms(scratch, current)
}
|
|
|
|
// writeEntries copies orig to dest, replacing the managed lines according to
// ipsToNames.
//
// Lines whose trimmed text does not end with the banner are copied verbatim.
// Lines that do end with it are managed:
//   - A managed line whose first field is an IP present in ipsToNames is
//     rewritten with the current names. Only the first such line per IP is
//     kept.
//   - Every other managed line is dropped.
//
// IPs from ipsToNames that had no managed line are appended at the end, sorted
// by IP so the output is deterministic. ipsToNames is not modified.
func (eh *EtcHosts) writeEntries(orig io.Reader, dest io.Writer, ipsToNames map[string][]string) error {
	banner := eh.Banner
	// A banner that is empty after trimming would match the end of every line.
	// The whole file would then look managed and every unmanaged line would be
	// pruned, so fall back to the default in that case as well.
	if strings.TrimSpace(banner) == "" {
		banner = DefaultBanner
	}
	marker := strings.TrimSpace(banner)

	// Remember which IPs were already rewritten instead of deleting them from
	// ipsToNames, so the caller's map is left untouched.
	written := make(map[string]struct{}, len(ipsToNames))

	// go through file and update existing entries/prune nonexistent entries
	scanner := bufio.NewScanner(orig)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasSuffix(strings.TrimSpace(line), marker) {
			// keep original unmanaged line
			if _, err := fmt.Fprintf(dest, "%s\n", line); err != nil {
				return errors.Wrap(err, "error writing hosts file")
			}
			continue
		}

		tokens := strings.Fields(line)
		if len(tokens) < 1 {
			continue // remove empty managed line
		}
		ip := tokens[0]
		if _, done := written[ip]; done {
			continue // duplicate managed line for an IP already rewritten
		}
		names, ok := ipsToNames[ip]
		if !ok {
			continue // prune entries that are no longer requested
		}
		if err := eh.writeEntryWithBanner(dest, banner, ip, names); err != nil {
			return err
		}
		written[ip] = struct{}{}
	}
	if err := scanner.Err(); err != nil {
		return errors.Wrap(err, "error reading hosts file")
	}

	// append remaining entries in a stable (sorted) order
	remaining := make([]string, 0, len(ipsToNames))
	for ip := range ipsToNames {
		if _, done := written[ip]; !done {
			remaining = append(remaining, ip)
		}
	}
	sort.Strings(remaining)
	for _, ip := range remaining {
		if err := eh.writeEntryWithBanner(dest, banner, ip, ipsToNames[ip]); err != nil {
			return err
		}
	}

	return nil
}
|
|
|
|
// writeEntryWithBanner writes one managed hosts line to tmp. The line has the
// form "<ip>\t<names joined by spaces>\t<banner>\n".
//
// Nothing is written, and nil is returned, when ip is empty or names is empty.
// Write failures are returned wrapped with the offending IP.
func (eh *EtcHosts) writeEntryWithBanner(tmp io.Writer, banner, ip string, names []string) error {
	if ip == "" || len(names) == 0 {
		return nil
	}

	if eh.Logger != nil {
		eh.Logger.Printf("writing entry for %s (%s)", ip, names)
	}

	hostnames := strings.Join(names, " ")
	_, err := fmt.Fprintf(tmp, "%s\t%s\t%s\n", ip, hostnames, banner)
	if err != nil {
		return errors.Wrapf(err, "error writing entry for %s", ip)
	}
	return nil
}
|
|
|
|
// movePreservePerms replaces dst with the contents of src, keeping dst's
// permission bits.
//
// The steps are:
//  1. src is synced to disk.
//  2. src is given dst's mode, before it becomes visible.
//  3. src is renamed over dst's path.
//
// If the rename fails, the contents of src are instead copied into dst in
// place (truncate + copy + sync). This keeps dst's own file and mode. File
// ownership is not preserved on the rename path.
func (eh *EtcHosts) movePreservePerms(src, dst *os.File) error {
	if err := src.Sync(); err != nil {
		return errors.Wrapf(err, "could not sync changes to %s", src.Name())
	}

	etcHostsInfo, err := dst.Stat()
	if err != nil {
		return errors.Wrapf(err, "could not stat %s", dst.Name())
	}

	// Apply dst's mode before the rename. The temp file is created with
	// restrictive permissions (0600, further affected by umask). Chmod-ing only
	// after the rename leaves a window in which the live hosts file has that
	// mode and may be unreadable to other users.
	if err := src.Chmod(etcHostsInfo.Mode()); err != nil {
		return errors.Wrapf(err, "could not chmod %s", src.Name())
	}
	// TODO: also keep user?

	if err := os.Rename(src.Name(), dst.Name()); err != nil {
		if eh.Logger != nil {
			eh.Logger.Printf("could not rename to %s; falling back to copy (%s)", dst.Name(), err)
		}

		if _, err := src.Seek(0, io.SeekStart); err != nil {
			return errors.Wrapf(err, "could not rewind %s", src.Name())
		}
		if _, err := dst.Seek(0, io.SeekStart); err != nil {
			return errors.Wrapf(err, "could not rewind %s", dst.Name())
		}
		if err := dst.Truncate(0); err != nil {
			return errors.Wrapf(err, "could not truncate %s", dst.Name())
		}
		if _, err := io.Copy(dst, src); err != nil {
			return errors.Wrapf(err, "could not copy to %s", dst.Name())
		}
		// Flush the in-place copy as well, mirroring the Sync done on src above.
		if err := dst.Sync(); err != nil {
			return errors.Wrapf(err, "could not sync changes to %s", dst.Name())
		}
	}

	return nil
}
|